diff --git a/pom.xml b/pom.xml index f6cb92ca9..ff81fb87b 100644 --- a/pom.xml +++ b/pom.xml @@ -77,7 +77,7 @@ 1.7.30 2.17.1 4.13.2 - 5.7.0 + 5.10.1 1.18.22 4.10.0 3.0.24 @@ -89,11 +89,12 @@ 3.0.0-M2 2.19.1 1.1.0 - 5.1.0 + 5.10.1 2.12.7.1 2.10.1 1.6.20 1.2.83 + 5.8.0 @@ -109,6 +110,18 @@ + + org.mockito + mockito-core + ${mockito.version} + test + + + org.mockito + mockito-junit-jupiter + ${mockito.version} + test + io.grpc grpc-netty-shaded diff --git a/src/main/java/io/milvus/v2/client/ConnectConfig.java b/src/main/java/io/milvus/v2/client/ConnectConfig.java new file mode 100644 index 000000000..9c8024a79 --- /dev/null +++ b/src/main/java/io/milvus/v2/client/ConnectConfig.java @@ -0,0 +1,60 @@ +package io.milvus.v2.client; + +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.concurrent.TimeUnit; + +@Data +@SuperBuilder +public class ConnectConfig { + private String uri; + private String token; + private String username; + private String password; + private String databaseName; + @Builder.Default + private long connectTimeoutMs = 10000; + @Builder.Default + private long keepAliveTimeMs = 55000; + @Builder.Default + private long keepAliveTimeoutMs = 20000; + @Builder.Default + private boolean keepAliveWithoutCalls = false; + @Builder.Default + private long rpcDeadlineMs = 0; // Disabling deadline + + private String clientKeyPath; + private String clientPemPath; + private String caPemPath; + private String serverPemPath; + private String serverName; + + @Builder.Default + private boolean secure = true; + @Builder.Default + private long idleTimeoutMs = TimeUnit.MILLISECONDS.convert(24, TimeUnit.HOURS); + + public String getHost() { + URI uri = URI.create(this.uri); + return uri.getHost(); + } + + public int getPort() { + URI uri = URI.create(this.uri); + return uri.getPort(); + } + + public String getAuthorization() { + if (token != null) { + return token; + }else if (username != null && password != null) { + return username + ":" + password; + } + return null; + } +} diff --git a/src/main/java/io/milvus/v2/client/MilvusClientV2.java b/src/main/java/io/milvus/v2/client/MilvusClientV2.java new file mode 100644 index 000000000..b167db5b0 --- /dev/null +++ b/src/main/java/io/milvus/v2/client/MilvusClientV2.java @@ -0,0 +1,521 @@ +package io.milvus.v2.client; + +import io.grpc.ManagedChannel; +import io.milvus.grpc.MilvusServiceGrpc; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.service.collection.CollectionService; +import io.milvus.v2.service.collection.request.*; +import io.milvus.v2.service.collection.response.DescribeCollectionResp; +import io.milvus.v2.service.collection.response.ListCollectionsResp; +import io.milvus.v2.service.index.IndexService; +import io.milvus.v2.service.index.request.CreateIndexReq; +import io.milvus.v2.service.index.request.DescribeIndexReq; +import io.milvus.v2.service.index.request.DropIndexReq; +import io.milvus.v2.service.index.response.DescribeIndexResp; +import io.milvus.v2.service.partition.PartitionService; +import io.milvus.v2.service.partition.request.*; +import io.milvus.v2.service.rbac.RoleService; +import io.milvus.v2.service.rbac.UserService; +import io.milvus.v2.service.rbac.request.*; +import io.milvus.v2.service.rbac.response.DescribeRoleResp; +import io.milvus.v2.service.rbac.response.DescribeUserResp; +import io.milvus.v2.service.utility.UtilityService; +import io.milvus.v2.service.utility.request.AlterAliasReq; +import io.milvus.v2.service.utility.request.CreateAliasReq; +import io.milvus.v2.service.utility.request.DropAliasReq; +import io.milvus.v2.service.utility.request.FlushReq; +import io.milvus.v2.service.vector.VectorService; +import io.milvus.v2.service.vector.request.*; +import io.milvus.v2.service.vector.response.GetResp; +import io.milvus.v2.service.vector.response.QueryResp; +import io.milvus.v2.service.vector.response.SearchResp; +import io.milvus.v2.utils.ClientUtils; +import lombok.NonNull; +import lombok.Setter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.concurrent.TimeUnit; + +public class MilvusClientV2 { + private static final Logger logger = LoggerFactory.getLogger(MilvusClientV2.class); + private ManagedChannel channel; + @Setter + private MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub; + private final ClientUtils clientUtils = new ClientUtils(); + private final CollectionService collectionService = new CollectionService(); + private final IndexService indexService = new IndexService(); + private final VectorService vectorService = new VectorService(); + private final PartitionService partitionService = new PartitionService(); + private final UserService userService = new UserService(); + private final RoleService roleService = new RoleService(); + private final UtilityService utilityService = new UtilityService(); + private ConnectConfig connectConfig; + + /** + * Creates a Milvus client instance. + * @param connectConfig Milvus server connection configuration + */ + public MilvusClientV2(ConnectConfig connectConfig) { + if (connectConfig != null) { + connect(connectConfig); + } + } + /** + * connect to Milvus server + * + * @param connectConfig Milvus server connection configuration + */ + private void connect(ConnectConfig connectConfig){ + this.connectConfig = connectConfig; + try { + if(this.channel != null) { + // close channel first + close(3); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + channel = clientUtils.getChannel(connectConfig); + + if (connectConfig.getRpcDeadlineMs() > 0) { + blockingStub = MilvusServiceGrpc.newBlockingStub(channel).withWaitForReady() + .withDeadlineAfter(connectConfig.getRpcDeadlineMs(), TimeUnit.MILLISECONDS); + }else { + blockingStub = MilvusServiceGrpc.newBlockingStub(channel); + } + + if (connectConfig.getDatabaseName() != null) { + // check if database exists + clientUtils.checkDatabaseExist(this.blockingStub, connectConfig.getDatabaseName()); + } + } + + /** + * use Database + * @param dbName databaseName + */ + public void useDatabase(@NonNull String dbName) { + // check if database exists + clientUtils.checkDatabaseExist(this.blockingStub, dbName); + try { + this.connectConfig.setDatabaseName(dbName); + this.close(3); + this.connect(this.connectConfig); + }catch (InterruptedException e){ + logger.error("close connect error"); + } + } + + //Collection Operations + /** + * Fast Creates a collection in Milvus. + * + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R createCollection(CreateCollectionReq request) { + return collectionService.createCollection(this.blockingStub, request); + } + + /** + * Creates a collection with Schema in Milvus. + * + * @param request create collection request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R createCollectionWithSchema(CreateCollectionWithSchemaReq request) { + return collectionService.createCollectionWithSchema(this.blockingStub, request); + } + + /** + * list milvus collections + * + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R listCollections(){ + return collectionService.listCollections(this.blockingStub); + } + + /** + * Drops a collection in Milvus. + * + * @param request drop collection request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R dropCollection(DropCollectionReq request){ + return collectionService.dropCollection(this.blockingStub, request); + } + /** + * Checks whether a collection exists in Milvus. + * + * @param request has collection request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R hasCollection(HasCollectionReq request){ + return collectionService.hasCollection(this.blockingStub, request); + } + /** + * Gets the collection info in Milvus. + * + * @param request describe collection request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R describeCollection(DescribeCollectionReq request){ + return collectionService.describeCollection(this.blockingStub, request); + } + /** + * get collection stats for a collection in Milvus. + * + * @param request get collection stats request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ +// public R getCollectionStats(GetCollectionStatsReq request){ +// return collectionService.getCollectionStats(this.blockingStub, request); +// } + /** + * rename collection in a collection in Milvus. + * + * @param request rename collection request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R renameCollection(RenameCollectionReq request){ + return collectionService.renameCollection(this.blockingStub, request); + } + /** + * Loads a collection into memory in Milvus. + * + * @param request load collection request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R loadCollection(LoadCollectionReq request){ + return collectionService.loadCollection(this.blockingStub, request); + } + /** + * Releases a collection from memory in Milvus. + * + * @param request release collection request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R releaseCollection(ReleaseCollectionReq request){ + return collectionService.releaseCollection(this.blockingStub, request); + } + /** + * Checks whether a collection is loaded in Milvus. + * + * @param request get load state request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R getLoadState(GetLoadStateReq request){ + return collectionService.getLoadState(this.blockingStub, request); + } + + //Index Operations + /** + * Creates an index for a specified field in a collection in Milvus. + * + * @param request create index request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R createIndex(CreateIndexReq request){ + return indexService.createIndex(this.blockingStub, request); + } + /** + * Drops an index for a specified field in a collection in Milvus. + * + * @param request drop index request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R dropIndex(DropIndexReq request){ + return indexService.dropIndex(this.blockingStub, request); + } + /** + * Checks whether an index exists for a specified field in a collection in Milvus. + * + * @param request describe index request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R describeIndex(DescribeIndexReq request){ + return indexService.describeIndex(this.blockingStub, request); + } + + // Vector Operations + + /** + * Inserts vectors into a collection in Milvus. + * + * @param request insert request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R insert(InsertReq request){ + return vectorService.insert(this.blockingStub, request); + } + /** + * Upsert vectors into a collection in Milvus. + * + * @param request upsert request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R upsert(UpsertReq request){ + return vectorService.upsert(this.blockingStub, request); + } + /** + * Deletes vectors in a collection in Milvus. + * + * @param request delete request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R delete(DeleteReq request){ + return vectorService.delete(this.blockingStub, request); + } + /** + * Gets vectors in a collection in Milvus. + * + * @param request get request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R get(GetReq request){ + return vectorService.get(this.blockingStub, request); + } + + /** + * Queries vectors in a collection in Milvus. + * + * @param request query request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R query(QueryReq request){ + return vectorService.query(this.blockingStub, request); + } + /** + * Searches vectors in a collection in Milvus. + * + * @param request search request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R search(SearchReq request){ + return vectorService.search(this.blockingStub, request); + } + + // Partition Operations + /** + * Creates a partition in a collection in Milvus. + * + * @param request create partition request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R createPartition(CreatePartitionReq request) { + return partitionService.createPartition(this.blockingStub, request); + } + + /** + * Drops a partition in a collection in Milvus. + * + * @param request drop partition request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R dropPartition(DropPartitionReq request) { + return partitionService.dropPartition(this.blockingStub, request); + } + + /** + * Checks whether a partition exists in a collection in Milvus. + * + * @param request has partition request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R hasPartition(HasPartitionReq request) { + return partitionService.hasPartition(this.blockingStub, request); + } + + /** + * Lists all partitions in a collection in Milvus. + * + * @param request list partitions request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R> listPartitions(ListPartitionsReq request) { + return partitionService.listPartitions(this.blockingStub, request); + } + + /** + * Loads partitions in a collection in Milvus. + * + * @param request load partitions request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R loadPartitions(LoadPartitionsReq request) { + return partitionService.loadPartitions(this.blockingStub, request); + } + /** + * Releases partitions in a collection in Milvus. + * + * @param request release partitions request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R releasePartitions(ReleasePartitionsReq request) { + return partitionService.releasePartitions(this.blockingStub, request); + } + // rbac operations + // user operations + /** + * list users + * + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R> listUsers(){ + return userService.listUsers(this.blockingStub); + } + /** + * describe user + * + * @param request describe user request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R describeUser(DescribeUserReq request){ + return userService.describeUser(this.blockingStub, request); + } + /** + * create user + * + * @param request create user request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R createUser(CreateUserReq request){ + return userService.createUser(this.blockingStub, request); + } + /** + * change password + * + * @param request change password request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R updatePassword(UpdatePasswordReq request) { + return userService.updatePassword(this.blockingStub, request); + } + /** + * drop user + * + * @param request drop user request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R dropUser(DropUserReq request){ + return userService.dropUser(this.blockingStub, request); + } + // role operations + /** + * list roles + * + * @return {status:result code, data:List{msg: result message}} + */ + public R> listRoles() { + return roleService.listRoles(this.blockingStub); + } + /** + * describe role + * + * @param request describe role request + * @return {status:result code, data:DescribeRoleResp{msg: result message}} + */ + public R describeRole(DescribeRoleReq request) { + return roleService.describeRole(this.blockingStub, request); + } + /** + * create role + * + * @param request create role request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R createRole(CreateRoleReq request) { + return roleService.createRole(this.blockingStub, request); + } + /** + * drop role + * + * @param request drop role request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R dropRole(DropRoleReq request) { + return roleService.dropRole(this.blockingStub, request); + } + /** + * grant privilege + * + * @param request grant privilege request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R grantPrivilege(GrantPrivilegeReq request) { + return roleService.grantPrivilege(this.blockingStub, request); + } + /** + * revoke privilege + * + * @param request revoke privilege request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R revokePrivilege(RevokePrivilegeReq request) { + return roleService.revokePrivilege(this.blockingStub, request); + } + /** + * grant role + * + * @param request grant role request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R grantRole(GrantRoleReq request) { + return roleService.grantRole(this.blockingStub, request); + } + /** + * revoke role + * + * @param request revoke role request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R revokeRole(RevokeRoleReq request) { + return roleService.revokeRole(this.blockingStub, request); + } + + // Utility Operations + + /** + * create aliases + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R createAlias(CreateAliasReq request) { + return utilityService.createAlias(this.blockingStub, request); + } + /** + * drop aliases + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R dropAlias(DropAliasReq request) { + return utilityService.dropAlias(this.blockingStub, request); + } + /** + * alter aliases + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R alterAlias(AlterAliasReq request) { + return utilityService.alterAlias(this.blockingStub, request); + } + /** + * flush collection + * @param request flush request + * @return {status:result code, data:RpcStatus{msg: result message}} + */ + public R flush(FlushReq request) { + return utilityService.flush(this.blockingStub, request); + } + /** + * close client + * + * @param maxWaitSeconds max wait seconds + */ + public void close(long maxWaitSeconds) throws InterruptedException { + if(channel!= null){ + channel.shutdownNow(); + channel.awaitTermination(maxWaitSeconds, TimeUnit.SECONDS); + } + } +} diff --git a/src/main/java/io/milvus/v2/common/ConsistencyLevel.java b/src/main/java/io/milvus/v2/common/ConsistencyLevel.java new file mode 100644 index 000000000..c37b2a953 --- /dev/null +++ b/src/main/java/io/milvus/v2/common/ConsistencyLevel.java @@ -0,0 +1,17 @@ +package io.milvus.v2.common; + +import io.milvus.common.clientenum.ConsistencyLevelEnum; +import lombok.Getter; +@Getter +public enum ConsistencyLevel{ + STRONG("Strong", 0), + BOUNDED("Bounded", 2), + EVENTUALLY("Eventually",3), + ; + private final String name; + private final int code; + ConsistencyLevel(String name, int code) { + this.name = name; + this.code = code; + } +} diff --git a/src/main/java/io/milvus/v2/common/DataType.java b/src/main/java/io/milvus/v2/common/DataType.java new file mode 100644 index 000000000..8c3d538ce --- /dev/null +++ b/src/main/java/io/milvus/v2/common/DataType.java @@ -0,0 +1,32 @@ +package io.milvus.v2.common; + +import lombok.Getter; + +@Getter +public enum DataType { + None(0), + Bool(1), + Int8(2), + Int16(3), + Int32(4), + Int64(5), + + Float(10), + Double(11), + + String(20), + VarChar(21), // variable-length strings with a specified maximum length + Array(22), + JSON(23), + + BinaryVector(100), + FloatVector(101), + Float16Vector(102), + BFloat16Vector(103); + + private final int code; + DataType(int code) { + this.code = code; + } + ; +} diff --git a/src/main/java/io/milvus/v2/common/IndexParam.java b/src/main/java/io/milvus/v2/common/IndexParam.java new file mode 100644 index 000000000..c7770d37a --- /dev/null +++ b/src/main/java/io/milvus/v2/common/IndexParam.java @@ -0,0 +1,84 @@ +package io.milvus.v2.common; + +import lombok.Builder; +import lombok.Data; +import lombok.Getter; +import lombok.NonNull; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class IndexParam { + @NonNull + private String fieldName; + private String indexName; + @Builder.Default + private IndexType indexType = IndexType.AUTOINDEX; + private MetricType metricType; + + public String getIndexName() { + if(indexName == null) { + return fieldName; + } + return indexName; + } + + public enum MetricType { + INVALID, + // Only for float vectors + L2, + IP, + COSINE, + + // Only for binary vectors + HAMMING, + JACCARD, + ; + } + + @Getter + public enum IndexType { + INVALID, + //Only supported for float vectors + FLAT(1), + IVF_FLAT(2), + IVF_SQ8(3), + IVF_PQ(4), + HNSW(5), + DISKANN(10), + AUTOINDEX(11), + SCANN(12), + + // GPU index + GPU_IVF_FLAT(50), + GPU_IVF_PQ(51), + + //Only supported for binary vectors + BIN_FLAT(80), + BIN_IVF_FLAT(81), + + //Scalar field index start from here + //Only for varchar type field + TRIE("Trie", 100), + //Only for scalar type field + STL_SORT(200), + ; + private final String name; + private final int code; + + IndexType(){ + this.name = this.toString(); + this.code = this.ordinal(); + } + + IndexType(int code){ + this.name = this.toString(); + this.code = code; + } + + IndexType(String name, int code){ + this.name = name; + this.code = code; + } + } +} diff --git a/src/main/java/io/milvus/v2/service/BaseService.java b/src/main/java/io/milvus/v2/service/BaseService.java new file mode 100644 index 000000000..6560fab97 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/BaseService.java @@ -0,0 +1,27 @@ +package io.milvus.v2.service; + +import io.milvus.exception.MilvusException; +import io.milvus.grpc.BoolResponse; +import io.milvus.grpc.HasCollectionRequest; +import io.milvus.grpc.MilvusServiceGrpc; +import io.milvus.param.R; +import io.milvus.v2.utils.ConvertUtils; +import io.milvus.v2.utils.DataUtils; +import io.milvus.v2.utils.VectorUtils; +import io.milvus.v2.utils.RpcUtils; + +public class BaseService { + public RpcUtils rpcUtils = new RpcUtils(); + public DataUtils dataUtils = new DataUtils(); + public VectorUtils vectorUtils = new VectorUtils(); + public ConvertUtils convertUtils = new ConvertUtils(); + + protected void checkCollectionExist(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String collectionName) { + HasCollectionRequest request = HasCollectionRequest.newBuilder().setCollectionName(collectionName).build(); + BoolResponse result = blockingStub.hasCollection(request); + rpcUtils.handleResponse("", result.getStatus()); + if (!result.getValue() == Boolean.TRUE) { + throw new MilvusException("Collection " + collectionName + " not exist", R.Status.CollectionNotExists.getCode()); + } + } +} diff --git a/src/main/java/io/milvus/v2/service/collection/CollectionService.java b/src/main/java/io/milvus/v2/service/collection/CollectionService.java new file mode 100644 index 000000000..906e638ff --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/CollectionService.java @@ -0,0 +1,218 @@ +package io.milvus.v2.service.collection; + +import io.milvus.grpc.*; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.service.BaseService; +import io.milvus.v2.service.collection.request.*; +import io.milvus.v2.service.collection.response.DescribeCollectionResp; +import io.milvus.v2.service.collection.response.GetCollectionStatsResp; +import io.milvus.v2.service.collection.response.ListCollectionsResp; +import io.milvus.v2.common.IndexParam; +import io.milvus.v2.service.index.IndexService; +import io.milvus.v2.service.index.request.CreateIndexReq; +import io.milvus.v2.utils.SchemaUtils; + +public class CollectionService extends BaseService { + public IndexService indexService = new IndexService(); + + public R createCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreateCollectionReq request) { + String title = String.format("CreateCollectionRequest collectionName:%s", request.getCollectionName()); + FieldSchema vectorSchema = FieldSchema.newBuilder() + .setName(request.getVectorFieldName()) + .setDataType(DataType.FloatVector) + .setIsPrimaryKey(Boolean.FALSE) + .addTypeParams(KeyValuePair.newBuilder().setKey("dim").setValue(String.valueOf(request.getDimension())).build()) + .build(); + + FieldSchema idSchema = FieldSchema.newBuilder() + .setName("id") + .setDataType(DataType.valueOf(request.getPrimaryFieldType())) + .setIsPrimaryKey(Boolean.TRUE) + .setAutoID(request.getAutoID()) + .build(); + if(request.getPrimaryFieldType().equals("VarChar") && request.getMaxLength() != null){ + idSchema = idSchema.toBuilder().addTypeParams(KeyValuePair.newBuilder().setKey("max_length").setValue(String.valueOf(request.getMaxLength())).build()).build(); + } + + CollectionSchema schema = CollectionSchema.newBuilder() + .setName(request.getCollectionName()) + .addFields(vectorSchema) + .addFields(idSchema) + .setEnableDynamicField(Boolean.TRUE) + .build(); + + + CreateCollectionRequest createCollectionRequest = CreateCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setSchema(schema.toByteString()) + .build(); + + Status status = blockingStub.createCollection(createCollectionRequest); + rpcUtils.handleResponse(title, status); + + //create index + IndexParam indexParam = IndexParam.builder() + .metricType(IndexParam.MetricType.valueOf(request.getMetricType())) + .fieldName("vector") + .build(); + CreateIndexReq createIndexReq = CreateIndexReq.builder() + .indexParam(indexParam) + .collectionName(request.getCollectionName()) + .build(); + indexService.createIndex(blockingStub, createIndexReq); + //load collection + loadCollection(blockingStub, LoadCollectionReq.builder().collectionName(request.getCollectionName()).build()); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R createCollectionWithSchema(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreateCollectionWithSchemaReq request) { + String title = String.format("CreateCollectionRequest collectionName:%s", request.getCollectionName()); + + //convert CollectionSchema to io.milvus.grpc.CollectionSchema + CollectionSchema grpcSchema = CollectionSchema.newBuilder() + .setName(request.getCollectionName()) + .setDescription(request.getCollectionSchema().getDescription()) + .setEnableDynamicField(request.getCollectionSchema().getEnableDynamicField()) + .build(); + for (CreateCollectionWithSchemaReq.FieldSchema fieldSchema : request.getCollectionSchema().getFieldSchemaList()) { + grpcSchema = grpcSchema.toBuilder().addFields(SchemaUtils.convertToGrpcFieldSchema(fieldSchema)).build(); + } + + //create collection + CreateCollectionRequest createCollectionRequest = CreateCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setSchema(grpcSchema.toByteString()) + .build(); + + Status createCollectionResponse = blockingStub.createCollection(createCollectionRequest); + rpcUtils.handleResponse(title, createCollectionResponse); + + //create index + if(request.getIndexParams() != null && !request.getIndexParams().isEmpty()) { + for(IndexParam indexParam : request.getIndexParams()) { + CreateIndexReq createIndexReq = CreateIndexReq.builder() + .indexParam(indexParam) + .collectionName(request.getCollectionName()) + .build(); + indexService.createIndex(blockingStub, createIndexReq); + } + } + + //load collection + loadCollection(blockingStub, LoadCollectionReq.builder().collectionName(request.getCollectionName()).build()); + + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R listCollections(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub) { + ShowCollectionsRequest showCollectionsRequest = ShowCollectionsRequest.newBuilder() + .build(); + ShowCollectionsResponse response = milvusServiceBlockingStub.showCollections(showCollectionsRequest); + ListCollectionsResp a = ListCollectionsResp.builder() + .collectionNames(response.getCollectionNamesList()) + .build(); + + return R.success(a); + } + + public R dropCollection(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, DropCollectionReq request) { + + String title = String.format("DropCollectionRequest collectionName:%s", request.getCollectionName()); + DropCollectionRequest dropCollectionRequest = DropCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .build(); + Status status = milvusServiceBlockingStub.dropCollection(dropCollectionRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R hasCollection(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, HasCollectionReq request) { + HasCollectionRequest hasCollectionRequest = HasCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .build(); + BoolResponse response = milvusServiceBlockingStub.hasCollection(hasCollectionRequest); + return R.success(response.getValue()); + } + + public R describeCollection(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, DescribeCollectionReq request) { + //check collection exists + checkCollectionExist(milvusServiceBlockingStub, request.getCollectionName()); + + DescribeCollectionRequest describeCollectionRequest = DescribeCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .build(); + DescribeCollectionResponse response = milvusServiceBlockingStub.describeCollection(describeCollectionRequest); + + DescribeCollectionResp describeCollectionResp = DescribeCollectionResp.builder() + .collectionName(response.getCollectionName()) + .description(response.getSchema().getDescription()) + .numOfPartitions(response.getNumPartitions()) + .collectionSchema(SchemaUtils.convertFromGrpcCollectionSchema(response.getSchema())) + .autoID(response.getSchema().getFieldsList().stream().anyMatch(FieldSchema::getAutoID)) + .enableDynamicField(response.getSchema().getEnableDynamicField()) + .fieldNames(response.getSchema().getFieldsList().stream().map(FieldSchema::getName).collect(java.util.stream.Collectors.toList())) + .vectorFieldName(response.getSchema().getFieldsList().stream().filter(fieldSchema -> fieldSchema.getDataType() == DataType.FloatVector || fieldSchema.getDataType() == DataType.BinaryVector).map(FieldSchema::getName).collect(java.util.stream.Collectors.toList())) + .primaryFieldName(response.getSchema().getFieldsList().stream().filter(FieldSchema::getIsPrimaryKey).map(FieldSchema::getName).collect(java.util.stream.Collectors.toList()).get(0)) + .createTime(response.getCreatedTimestamp()) + .build(); + + return R.success(describeCollectionResp); + } + + public R renameCollection(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, RenameCollectionReq request) { + String title = String.format("RenameCollectionRequest collectionName:%s", request.getCollectionName()); + RenameCollectionRequest renameCollectionRequest = RenameCollectionRequest.newBuilder() + .setOldName(request.getCollectionName()) + .setNewName(request.getNewCollectionName()) + .build(); + Status status = milvusServiceBlockingStub.renameCollection(renameCollectionRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, LoadCollectionReq request) { + String title = String.format("LoadCollectionRequest collectionName:%s", request.getCollectionName()); + LoadCollectionRequest loadCollectionRequest = LoadCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .build(); + Status status = milvusServiceBlockingStub.loadCollection(loadCollectionRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R releaseCollection(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, ReleaseCollectionReq request) { + String title = String.format("ReleaseCollectionRequest collectionName:%s", request.getCollectionName()); + ReleaseCollectionRequest releaseCollectionRequest = ReleaseCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .build(); + Status status = milvusServiceBlockingStub.releaseCollection(releaseCollectionRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R getLoadState(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, GetLoadStateReq request) { + // getLoadState + String title = String.format("GetLoadStateRequest collectionName:%s", request.getCollectionName()); + GetLoadStateRequest getLoadStateRequest = GetLoadStateRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .build(); + GetLoadStateResponse response = milvusServiceBlockingStub.getLoadState(getLoadStateRequest); + rpcUtils.handleResponse(title, response.getStatus()); + return R.success(response.getState() == LoadState.LoadStateLoaded); + } + + public R getCollectionStats(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, GetCollectionStatsReq request) { + String title = String.format("GetCollectionStatisticsRequest collectionName:%s", request.getCollectionName()); + GetCollectionStatisticsRequest getCollectionStatisticsRequest = GetCollectionStatisticsRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .build(); + GetCollectionStatisticsResponse response = blockingStub.getCollectionStatistics(getCollectionStatisticsRequest); + + rpcUtils.handleResponse(title, response.getStatus()); + GetCollectionStatsResp getCollectionStatsResp = GetCollectionStatsResp.builder() + .numOfEntities(response.getStatsList().stream().filter(stat -> stat.getKey().equals("row_count")).map(stat -> Long.parseLong(stat.getValue())).findFirst().get()) + .build(); + return R.success(getCollectionStatsResp); + } +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java b/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java new file mode 100644 index 000000000..ebbf5b2f9 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionReq.java @@ -0,0 +1,27 @@ +package io.milvus.v2.service.collection.request; + +import io.milvus.v2.common.DataType; +import io.milvus.v2.common.IndexParam; +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class CreateCollectionReq { + private String collectionName; + private Integer dimension; + + @Builder.Default + private String primaryFieldName = "id"; + @Builder.Default + private String primaryFieldType = DataType.VarChar.name(); + @Builder.Default + private Integer maxLength = 65535; + @Builder.Default + private String vectorFieldName = "vector"; + @Builder.Default + private String metricType = IndexParam.MetricType.IP.name(); + @Builder.Default + private Boolean autoID = Boolean.TRUE; +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionWithSchemaReq.java b/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionWithSchemaReq.java new file mode 100644 index 000000000..1aeb90ef9 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/CreateCollectionWithSchemaReq.java @@ -0,0 +1,52 @@ +package io.milvus.v2.service.collection.request; + +import io.milvus.v2.common.DataType; +import io.milvus.v2.common.IndexParam; +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@SuperBuilder +@Data +public class CreateCollectionWithSchemaReq { + private String collectionName; + private CollectionSchema collectionSchema; + private List indexParams; + + @Data + @SuperBuilder + public static class CollectionSchema { + private List fieldSchemaList; + @Builder.Default + private String description = ""; + private Boolean enableDynamicField; + + public FieldSchema getField(String fieldName) { + for (FieldSchema field : fieldSchemaList) { + if (field.getName().equals(fieldName)) { + return field; + } + } + return null; + } + } + + @Data + @SuperBuilder + public static class FieldSchema { + //TODO: check here + private String name; + private DataType dataType; + @Builder.Default + private Integer maxLength = 65535; + private Integer dimension; + @Builder.Default + private Boolean isPrimaryKey = Boolean.FALSE; + @Builder.Default + private Boolean autoID = Boolean.FALSE; + } + + +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/DescribeCollectionReq.java b/src/main/java/io/milvus/v2/service/collection/request/DescribeCollectionReq.java new file mode 100644 index 000000000..e27780c36 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/DescribeCollectionReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DescribeCollectionReq { + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/DropCollectionReq.java b/src/main/java/io/milvus/v2/service/collection/request/DropCollectionReq.java new file mode 100644 index 000000000..05134415b --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/DropCollectionReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DropCollectionReq { + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/GetCollectionStatsReq.java b/src/main/java/io/milvus/v2/service/collection/request/GetCollectionStatsReq.java new file mode 100644 index 000000000..17d6309f1 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/GetCollectionStatsReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class GetCollectionStatsReq { + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/GetLoadStateReq.java b/src/main/java/io/milvus/v2/service/collection/request/GetLoadStateReq.java new file mode 100644 index 000000000..e102645d1 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/GetLoadStateReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class GetLoadStateReq { + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/HasCollectionReq.java b/src/main/java/io/milvus/v2/service/collection/request/HasCollectionReq.java new file mode 100644 index 000000000..c0deabf38 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/HasCollectionReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class HasCollectionReq { + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/LoadCollectionReq.java b/src/main/java/io/milvus/v2/service/collection/request/LoadCollectionReq.java new file mode 100644 index 000000000..8c022bf37 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/LoadCollectionReq.java @@ -0,0 +1,13 @@ +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class LoadCollectionReq { + private String collectionName; + private List partitionNames; +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/ReleaseCollectionReq.java b/src/main/java/io/milvus/v2/service/collection/request/ReleaseCollectionReq.java new file mode 100644 index 000000000..d91888384 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/ReleaseCollectionReq.java @@ -0,0 +1,13 @@ +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class ReleaseCollectionReq { + private String collectionName; + private List partitionNames; +} diff --git a/src/main/java/io/milvus/v2/service/collection/request/RenameCollectionReq.java b/src/main/java/io/milvus/v2/service/collection/request/RenameCollectionReq.java new file mode 100644 index 000000000..3c27cc57f --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/request/RenameCollectionReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.collection.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class RenameCollectionReq { + private String collectionName; + private String newCollectionName; +} diff --git a/src/main/java/io/milvus/v2/service/collection/response/DescribeCollectionResp.java b/src/main/java/io/milvus/v2/service/collection/response/DescribeCollectionResp.java new file mode 100644 index 000000000..3792335c0 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/response/DescribeCollectionResp.java @@ -0,0 +1,24 @@ +package io.milvus.v2.service.collection.response; + +import io.milvus.v2.service.collection.request.CreateCollectionWithSchemaReq; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class DescribeCollectionResp { + private String collectionName; + private String description; + private Long numOfPartitions; + + private List fieldNames; + private List vectorFieldName; + private String primaryFieldName; + private Boolean enableDynamicField; + private Boolean autoID; + + private CreateCollectionWithSchemaReq.CollectionSchema collectionSchema; + private Long createTime; +} diff --git a/src/main/java/io/milvus/v2/service/collection/response/GetCollectionStatsResp.java b/src/main/java/io/milvus/v2/service/collection/response/GetCollectionStatsResp.java new file mode 100644 index 000000000..f440e8e7f --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/response/GetCollectionStatsResp.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.collection.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class GetCollectionStatsResp { + private Long numOfEntities; +} diff --git a/src/main/java/io/milvus/v2/service/collection/response/ListCollectionsResp.java b/src/main/java/io/milvus/v2/service/collection/response/ListCollectionsResp.java new file mode 100644 index 000000000..ee28fdc46 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/collection/response/ListCollectionsResp.java @@ -0,0 +1,12 @@ +package io.milvus.v2.service.collection.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class ListCollectionsResp { + private List collectionNames; +} diff --git a/src/main/java/io/milvus/v2/service/index/IndexService.java b/src/main/java/io/milvus/v2/service/index/IndexService.java new file mode 100644 index 000000000..9291ffe87 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/index/IndexService.java @@ -0,0 +1,64 @@ +package io.milvus.v2.service.index; + +import io.milvus.grpc.*; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.common.IndexParam; +import io.milvus.v2.service.BaseService; +import io.milvus.v2.service.index.request.CreateIndexReq; +import io.milvus.v2.service.index.request.DescribeIndexReq; +import io.milvus.v2.service.index.request.DropIndexReq; +import io.milvus.v2.service.index.response.DescribeIndexResp; + +public class IndexService extends BaseService { + public R createIndex(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, CreateIndexReq request) { + String title = String.format("CreateIndexRequest collectionName:%s, fieldName:%s", + request.getCollectionName(), request.getIndexParam().getFieldName()); + CreateIndexRequest createIndexRequest = CreateIndexRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setIndexName(request.getIndexParam().getIndexName()) + .setFieldName(request.getIndexParam().getFieldName()) + .addExtraParams(KeyValuePair.newBuilder() + .setKey("index_type") + .setValue(String.valueOf(request.getIndexParam().getIndexType())) + .build()) + .addExtraParams(KeyValuePair.newBuilder() + .setKey("metric_type") + .setValue(String.valueOf(request.getIndexParam().getMetricType())) + .build()) + .build(); + + Status status = milvusServiceBlockingStub.createIndex(createIndexRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R dropIndex(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, DropIndexReq request) { + String title = String.format("DropIndexRequest collectionName:%s, fieldName:%s, indexName:%s", + request.getCollectionName(), request.getFieldName(), request.getIndexName()); + DropIndexRequest dropIndexRequest = DropIndexRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setFieldName(request.getFieldName()) + .setIndexName(request.getIndexName()) + .build(); + + Status status = milvusServiceBlockingStub.dropIndex(dropIndexRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R describeIndex(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, DescribeIndexReq request) { + String title = String.format("DescribeIndexRequest collectionName:%s, fieldName:%s, indexName:%s", + request.getCollectionName(), request.getFieldName(), request.getIndexName()); + DescribeIndexRequest describeIndexRequest = DescribeIndexRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setFieldName(request.getFieldName()) + .setIndexName(request.getIndexName()) + .build(); + + DescribeIndexResponse response = milvusServiceBlockingStub.describeIndex(describeIndexRequest); + rpcUtils.handleResponse(title, response.getStatus()); + + return convertUtils.convertToDescribeIndexResp(response); + } +} diff --git a/src/main/java/io/milvus/v2/service/index/request/CreateIndexReq.java b/src/main/java/io/milvus/v2/service/index/request/CreateIndexReq.java new file mode 100644 index 000000000..422c54123 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/index/request/CreateIndexReq.java @@ -0,0 +1,12 @@ +package io.milvus.v2.service.index.request; + +import io.milvus.v2.common.IndexParam; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class CreateIndexReq { + private String collectionName; + private IndexParam indexParam; +} diff --git a/src/main/java/io/milvus/v2/service/index/request/DescribeIndexReq.java b/src/main/java/io/milvus/v2/service/index/request/DescribeIndexReq.java new file mode 100644 index 000000000..1c89eca70 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/index/request/DescribeIndexReq.java @@ -0,0 +1,14 @@ +package io.milvus.v2.service.index.request; + +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DescribeIndexReq { + private String collectionName; + private String fieldName; + @Builder.Default + private String indexName = ""; +} diff --git a/src/main/java/io/milvus/v2/service/index/request/DropIndexReq.java b/src/main/java/io/milvus/v2/service/index/request/DropIndexReq.java new file mode 100644 index 000000000..75fdea3ea --- /dev/null +++ b/src/main/java/io/milvus/v2/service/index/request/DropIndexReq.java @@ -0,0 +1,14 @@ +package io.milvus.v2.service.index.request; + +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DropIndexReq { + private String collectionName; + private String fieldName; + @Builder.Default + private String indexName = ""; +} diff --git a/src/main/java/io/milvus/v2/service/index/response/DescribeIndexResp.java b/src/main/java/io/milvus/v2/service/index/response/DescribeIndexResp.java new file mode 100644 index 000000000..16d3aa22b --- /dev/null +++ b/src/main/java/io/milvus/v2/service/index/response/DescribeIndexResp.java @@ -0,0 +1,13 @@ +package io.milvus.v2.service.index.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DescribeIndexResp { + private String indexName; + private String indexType; + private String metricType; + private String fieldName; +} diff --git a/src/main/java/io/milvus/v2/service/partition/PartitionService.java b/src/main/java/io/milvus/v2/service/partition/PartitionService.java new file mode 100644 index 000000000..e0a3df78c --- /dev/null +++ b/src/main/java/io/milvus/v2/service/partition/PartitionService.java @@ -0,0 +1,88 @@ +package io.milvus.v2.service.partition; + +import io.milvus.grpc.CreatePartitionRequest; +import io.milvus.grpc.MilvusServiceGrpc; +import io.milvus.grpc.Status; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.service.partition.request.*; +import io.milvus.v2.service.BaseService; +import io.milvus.v2.service.collection.request.LoadCollectionReq; +import io.milvus.v2.service.collection.request.ReleaseCollectionReq; + +import java.util.List; + +public class PartitionService extends BaseService { + public R createPartition(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreatePartitionReq request) { + String title = String.format("Create partition %s in collection %s", request.getPartitionName(), request.getCollectionName()); + + CreatePartitionRequest createPartitionRequest = io.milvus.grpc.CreatePartitionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setPartitionName(request.getPartitionName()).build(); + + Status status = blockingStub.createPartition(createPartitionRequest); + rpcUtils.handleResponse(title, status); + + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R dropPartition(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, DropPartitionReq request) { + String title = String.format("Drop partition %s in collection %s", request.getPartitionName(), request.getCollectionName()); + + io.milvus.grpc.DropPartitionRequest dropPartitionRequest = io.milvus.grpc.DropPartitionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setPartitionName(request.getPartitionName()).build(); + + Status status = blockingStub.dropPartition(dropPartitionRequest); + rpcUtils.handleResponse(title, status); + + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R hasPartition(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, HasPartitionReq request) { + String title = String.format("Has partition %s in collection %s", request.getPartitionName(), request.getCollectionName()); + + io.milvus.grpc.HasPartitionRequest hasPartitionRequest = io.milvus.grpc.HasPartitionRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setPartitionName(request.getPartitionName()).build(); + + io.milvus.grpc.BoolResponse boolResponse = blockingStub.hasPartition(hasPartitionRequest); + rpcUtils.handleResponse(title, boolResponse.getStatus()); + + return R.success(boolResponse.getValue()); + } + + public R> listPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, ListPartitionsReq request) { + String title = String.format("List partitions in collection %s", request.getCollectionName()); + + io.milvus.grpc.ShowPartitionsRequest showPartitionsRequest = io.milvus.grpc.ShowPartitionsRequest.newBuilder() + .setCollectionName(request.getCollectionName()).build(); + + io.milvus.grpc.ShowPartitionsResponse showPartitionsResponse = blockingStub.showPartitions(showPartitionsRequest); + rpcUtils.handleResponse(title, showPartitionsResponse.getStatus()); + + return R.success(showPartitionsResponse.getPartitionNamesList()); + } + + public R loadPartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, LoadPartitionsReq request) { + String title = String.format("Load partitions %s in collection %s", request.getPartitionNames(), request.getCollectionName()); + + io.milvus.grpc.LoadPartitionsRequest loadPartitionsRequest = io.milvus.grpc.LoadPartitionsRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .addAllPartitionNames(request.getPartitionNames()).build(); + Status status = blockingStub.loadPartitions(loadPartitionsRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R releasePartitions(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, ReleasePartitionsReq request) { + String title = String.format("Release partitions %s in collection %s", request.getPartitionNames(), request.getCollectionName()); + + io.milvus.grpc.ReleasePartitionsRequest releasePartitionsRequest = io.milvus.grpc.ReleasePartitionsRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .addAllPartitionNames(request.getPartitionNames()).build(); + Status status = blockingStub.releasePartitions(releasePartitionsRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } +} diff --git a/src/main/java/io/milvus/v2/service/partition/request/CreatePartitionReq.java b/src/main/java/io/milvus/v2/service/partition/request/CreatePartitionReq.java new file mode 100644 index 000000000..65fb764aa --- /dev/null +++ b/src/main/java/io/milvus/v2/service/partition/request/CreatePartitionReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.partition.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class CreatePartitionReq { + private String collectionName; + private String partitionName; +} diff --git a/src/main/java/io/milvus/v2/service/partition/request/DropPartitionReq.java b/src/main/java/io/milvus/v2/service/partition/request/DropPartitionReq.java new file mode 100644 index 000000000..600f611ec --- /dev/null +++ b/src/main/java/io/milvus/v2/service/partition/request/DropPartitionReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.partition.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DropPartitionReq { + private String collectionName; + private String partitionName; +} diff --git a/src/main/java/io/milvus/v2/service/partition/request/HasPartitionReq.java b/src/main/java/io/milvus/v2/service/partition/request/HasPartitionReq.java new file mode 100644 index 000000000..0f281e7e8 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/partition/request/HasPartitionReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.partition.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class HasPartitionReq { + private String collectionName; + private String partitionName; +} diff --git a/src/main/java/io/milvus/v2/service/partition/request/ListPartitionsReq.java b/src/main/java/io/milvus/v2/service/partition/request/ListPartitionsReq.java new file mode 100644 index 000000000..364627ef2 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/partition/request/ListPartitionsReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.partition.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class ListPartitionsReq { + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/partition/request/LoadPartitionsReq.java b/src/main/java/io/milvus/v2/service/partition/request/LoadPartitionsReq.java new file mode 100644 index 000000000..308185498 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/partition/request/LoadPartitionsReq.java @@ -0,0 +1,13 @@ +package io.milvus.v2.service.partition.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class LoadPartitionsReq { + private String collectionName; + private List partitionNames; +} diff --git a/src/main/java/io/milvus/v2/service/partition/request/ReleasePartitionsReq.java b/src/main/java/io/milvus/v2/service/partition/request/ReleasePartitionsReq.java new file mode 100644 index 000000000..4210387c3 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/partition/request/ReleasePartitionsReq.java @@ -0,0 +1,13 @@ +package io.milvus.v2.service.partition.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class ReleasePartitionsReq { + private String collectionName; + private List partitionNames; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/RoleService.java b/src/main/java/io/milvus/v2/service/rbac/RoleService.java new file mode 100644 index 000000000..d26f46310 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/RoleService.java @@ -0,0 +1,133 @@ +package io.milvus.v2.service.rbac; + +import io.milvus.grpc.*; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.service.BaseService; +import io.milvus.v2.service.rbac.request.*; +import io.milvus.v2.service.rbac.response.DescribeRoleResp; + +import java.util.List; +import java.util.stream.Collectors; + +public class RoleService extends BaseService { + + public R> listRoles(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub) { + String title = "listRoles"; + SelectRoleRequest request = SelectRoleRequest.newBuilder().build(); + SelectRoleResponse response = blockingStub.selectRole(request); + + rpcUtils.handleResponse(title, response.getStatus()); + List roles = response.getResultsList().stream().map(roleResult -> roleResult.getRole().getName()).collect(Collectors.toList()); + return R.success(roles); + } + + public R createRole(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreateRoleReq request) { + String title = "createRole"; + CreateRoleRequest createRoleRequest = CreateRoleRequest.newBuilder() + .setEntity(RoleEntity.newBuilder() + .setName(request.getRoleName()) + .build()) + .build(); + Status status = blockingStub.createRole(createRoleRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R describeRole(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, DescribeRoleReq request) { + String title = "describeRole"; + SelectGrantRequest selectGrantRequest = SelectGrantRequest.newBuilder() + .setEntity(GrantEntity.newBuilder() + .setRole(RoleEntity.newBuilder() + .setName(request.getRoleName()) + .build()) + .build()) + .build(); + SelectGrantResponse response = blockingStub.selectGrant(selectGrantRequest); + rpcUtils.handleResponse(title, response.getStatus()); + DescribeRoleResp describeRoleResp = DescribeRoleResp.builder() + .grantInfos(response.getEntitiesList().stream().map(grantEntity -> DescribeRoleResp.GrantInfo.builder() + .dbName(grantEntity.getDbName()) + .objectName(grantEntity.getObjectName()) + .objectType(grantEntity.getObject().getName()) + .privilege(grantEntity.getGrantor().getPrivilege().getName()) + .grantor(grantEntity.getGrantor().getUser().getName()) + .build()).collect(Collectors.toList())) + .build(); + return R.success(describeRoleResp); + } + + public R dropRole(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, DropRoleReq request) { + String title = "dropRole"; + DropRoleRequest dropRoleRequest = DropRoleRequest.newBuilder() + .setRoleName(request.getRoleName()) + .build(); + Status status = blockingStub.dropRole(dropRoleRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R grantPrivilege(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, GrantPrivilegeReq request) { + String title = "grantPrivilege"; + GrantEntity entity = GrantEntity.newBuilder() + .setRole(RoleEntity.newBuilder() + .setName(request.getRoleName()) + .build()) + .setObjectName(request.getObjectName()) + .setObject(ObjectEntity.newBuilder().setName(request.getObjectType()).build()) + .setGrantor(GrantorEntity.newBuilder() + .setPrivilege(PrivilegeEntity.newBuilder().setName(request.getPrivilege()).build()).build()) + .build(); + OperatePrivilegeRequest operatePrivilegeRequest = OperatePrivilegeRequest.newBuilder() + .setEntity(entity) + .setType(OperatePrivilegeType.Grant) + .build(); + Status status = blockingStub.operatePrivilege(operatePrivilegeRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R revokePrivilege(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, RevokePrivilegeReq request) { + String title = "revokePrivilege"; + GrantEntity entity = GrantEntity.newBuilder() + .setRole(RoleEntity.newBuilder() + .setName(request.getRoleName()) + .build()) + .setObjectName(request.getObjectName()) + .setObject(ObjectEntity.newBuilder().setName(request.getObjectType()).build()) + .setGrantor(GrantorEntity.newBuilder() + .setPrivilege(PrivilegeEntity.newBuilder().setName(request.getPrivilege()).build()).build()) + .build(); + OperatePrivilegeRequest operatePrivilegeRequest = OperatePrivilegeRequest.newBuilder() + .setEntity(entity) + .setType(OperatePrivilegeType.Revoke) + .build(); + Status status = blockingStub.operatePrivilege(operatePrivilegeRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R grantRole(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, GrantRoleReq request) { + String title = "grantRole"; + OperateUserRoleRequest operateUserRoleRequest = OperateUserRoleRequest.newBuilder() + .setUsername(request.getUserName()) + .setRoleName(request.getRoleName()) + .setType(OperateUserRoleType.AddUserToRole) + .build(); + Status status = blockingStub.operateUserRole(operateUserRoleRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R revokeRole(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, RevokeRoleReq request) { + String title = "grantRole"; + OperateUserRoleRequest operateUserRoleRequest = OperateUserRoleRequest.newBuilder() + .setUsername(request.getUserName()) + .setRoleName(request.getRoleName()) + .setType(OperateUserRoleType.RemoveUserFromRole) + .build(); + Status status = blockingStub.operateUserRole(operateUserRoleRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } +} diff --git a/src/main/java/io/milvus/v2/service/rbac/UserService.java b/src/main/java/io/milvus/v2/service/rbac/UserService.java new file mode 100644 index 000000000..7db4f0cac --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/UserService.java @@ -0,0 +1,73 @@ +package io.milvus.v2.service.rbac; + +import io.milvus.grpc.*; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.service.rbac.request.*; +import io.milvus.v2.service.BaseService; +import io.milvus.v2.service.rbac.response.DescribeUserResp; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import java.util.stream.Collectors; + +public class UserService extends BaseService { + + public R> listUsers(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub) { + String title = "list users"; + ListCredUsersRequest request = ListCredUsersRequest.newBuilder().build(); + ListCredUsersResponse response = blockingStub.listCredUsers(request); + rpcUtils.handleResponse(title, response.getStatus()); + return R.success(response.getUsernamesList()); + } + + public R describeUser(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, DescribeUserReq request) { + String title = String.format("describe user %s", request.getUserName()); + // TODO: check user exists + SelectUserRequest selectUserRequest = SelectUserRequest.newBuilder() + .setUser(UserEntity.newBuilder().setName(request.getUserName()).build()) + .setIncludeRoleInfo(Boolean.TRUE) + .build(); + io.milvus.grpc.SelectUserResponse response = blockingStub.selectUser(selectUserRequest); + rpcUtils.handleResponse(title, response.getStatus()); + DescribeUserResp describeUserResp = DescribeUserResp.builder() + .roles(response.getResultsList().isEmpty()? null : response.getResultsList().get(0).getRolesList().stream().map(RoleEntity::getName).collect(Collectors.toList())) + .build(); + return R.success(describeUserResp); + } + + public R createUser(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreateUserReq request) { + String title = String.format("create user %s", request.getUserName()); + CreateCredentialRequest createCredentialRequest = CreateCredentialRequest.newBuilder() + .setUsername(request.getUserName()) + .setPassword(Base64.getEncoder().encodeToString(request.getPassword().getBytes(StandardCharsets.UTF_8))) + .build(); + Status response = blockingStub.createCredential(createCredentialRequest); + rpcUtils.handleResponse(title, response); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + + public R updatePassword(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, UpdatePasswordReq request) { + String title = String.format("update password for user %s", request.getUserName()); + UpdateCredentialRequest updateCredentialRequest = UpdateCredentialRequest.newBuilder() + .setUsername(request.getUserName()) + .setOldPassword(Base64.getEncoder().encodeToString(request.getPassword().getBytes(StandardCharsets.UTF_8))) + .setNewPassword(Base64.getEncoder().encodeToString(request.getNewPassword().getBytes(StandardCharsets.UTF_8))) + .build(); + Status response = blockingStub.updateCredential(updateCredentialRequest); + rpcUtils.handleResponse(title, response); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R dropUser(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, DropUserReq request) { + String title = String.format("drop user %s", request.getUserName()); + DeleteCredentialRequest deleteCredentialRequest = DeleteCredentialRequest.newBuilder() + .setUsername(request.getUserName()) + .build(); + Status response = blockingStub.deleteCredential(deleteCredentialRequest); + rpcUtils.handleResponse(title, response); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/CreateRoleReq.java b/src/main/java/io/milvus/v2/service/rbac/request/CreateRoleReq.java new file mode 100644 index 000000000..55b53d092 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/CreateRoleReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class CreateRoleReq { + private String roleName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/CreateUserReq.java b/src/main/java/io/milvus/v2/service/rbac/request/CreateUserReq.java new file mode 100644 index 000000000..57b317510 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/CreateUserReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class CreateUserReq { + private String userName; + private String password; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/DescribeRoleReq.java b/src/main/java/io/milvus/v2/service/rbac/request/DescribeRoleReq.java new file mode 100644 index 000000000..9ca09de4e --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/DescribeRoleReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DescribeRoleReq { + private String roleName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/DescribeUserReq.java b/src/main/java/io/milvus/v2/service/rbac/request/DescribeUserReq.java new file mode 100644 index 000000000..444e4a6b9 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/DescribeUserReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DescribeUserReq { + private String userName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/DropRoleReq.java b/src/main/java/io/milvus/v2/service/rbac/request/DropRoleReq.java new file mode 100644 index 000000000..aa2544375 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/DropRoleReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DropRoleReq { + private String roleName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/DropUserReq.java b/src/main/java/io/milvus/v2/service/rbac/request/DropUserReq.java new file mode 100644 index 000000000..cb4333e85 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/DropUserReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DropUserReq { + private String userName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/GrantPrivilegeReq.java b/src/main/java/io/milvus/v2/service/rbac/request/GrantPrivilegeReq.java new file mode 100644 index 000000000..9455322a4 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/GrantPrivilegeReq.java @@ -0,0 +1,13 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class GrantPrivilegeReq { + private String roleName; + private String objectType; + private String privilege; + private String objectName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/GrantRoleReq.java b/src/main/java/io/milvus/v2/service/rbac/request/GrantRoleReq.java new file mode 100644 index 000000000..d0fa77ddc --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/GrantRoleReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class GrantRoleReq { + private String userName; + private String roleName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/RevokePrivilegeReq.java b/src/main/java/io/milvus/v2/service/rbac/request/RevokePrivilegeReq.java new file mode 100644 index 000000000..f739e6113 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/RevokePrivilegeReq.java @@ -0,0 +1,14 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class RevokePrivilegeReq { + private String roleName; + private String dbName; + private String objectType; + private String privilege; + private String objectName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/RevokeRoleReq.java b/src/main/java/io/milvus/v2/service/rbac/request/RevokeRoleReq.java new file mode 100644 index 000000000..303ff6359 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/RevokeRoleReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class RevokeRoleReq { + private String userName; + private String roleName; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/request/UpdatePasswordReq.java b/src/main/java/io/milvus/v2/service/rbac/request/UpdatePasswordReq.java new file mode 100644 index 000000000..d0348e294 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/request/UpdatePasswordReq.java @@ -0,0 +1,12 @@ +package io.milvus.v2.service.rbac.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class UpdatePasswordReq { + private String userName; + private String password; + private String newPassword; +} diff --git a/src/main/java/io/milvus/v2/service/rbac/response/DescribeRoleResp.java b/src/main/java/io/milvus/v2/service/rbac/response/DescribeRoleResp.java new file mode 100644 index 000000000..da2275335 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/response/DescribeRoleResp.java @@ -0,0 +1,23 @@ +package io.milvus.v2.service.rbac.response; + +import io.milvus.grpc.GrantEntity; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class DescribeRoleResp { + List grantInfos; + + @Data + @SuperBuilder + public static class GrantInfo { + private String objectType; + private String privilege; + private String objectName; + private String dbName; + private String grantor; + } +} diff --git a/src/main/java/io/milvus/v2/service/rbac/response/DescribeUserResp.java b/src/main/java/io/milvus/v2/service/rbac/response/DescribeUserResp.java new file mode 100644 index 000000000..de5a3d758 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/rbac/response/DescribeUserResp.java @@ -0,0 +1,12 @@ +package io.milvus.v2.service.rbac.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class DescribeUserResp { + private List roles; +} diff --git a/src/main/java/io/milvus/v2/service/utility/UtilityService.java b/src/main/java/io/milvus/v2/service/utility/UtilityService.java new file mode 100644 index 000000000..ef0aa5525 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/utility/UtilityService.java @@ -0,0 +1,86 @@ +package io.milvus.v2.service.utility; + +import io.milvus.grpc.FlushResponse; +import io.milvus.grpc.MilvusServiceGrpc; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.service.utility.request.FlushReq; +import io.milvus.v2.service.BaseService; +import io.milvus.v2.service.utility.request.AlterAliasReq; +import io.milvus.v2.service.utility.request.CreateAliasReq; +import io.milvus.v2.service.utility.request.DropAliasReq; +import io.milvus.v2.service.utility.response.DescribeAliasResp; +import io.milvus.v2.service.utility.response.ListAliasResp; + +public class UtilityService extends BaseService { + public R flush(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, FlushReq request) { + String title = String.format("Flush collection %s", request.getCollectionName()); + io.milvus.grpc.FlushRequest flushRequest = io.milvus.grpc.FlushRequest.newBuilder() + .addCollectionNames(request.getCollectionName()) + .build(); + FlushResponse status = blockingStub.flush(flushRequest); + rpcUtils.handleResponse(title, status.getStatus()); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R createAlias(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreateAliasReq request) { + String title = String.format("Create alias %s for collection %s", request.getAlias(), request.getCollectionName()); + io.milvus.grpc.CreateAliasRequest createAliasRequest = io.milvus.grpc.CreateAliasRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setAlias(request.getAlias()) + .build(); + io.milvus.grpc.Status status = blockingStub.createAlias(createAliasRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R dropAlias(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, DropAliasReq request) { + String title = String.format("Drop alias %s", request.getAlias()); + io.milvus.grpc.DropAliasRequest dropAliasRequest = io.milvus.grpc.DropAliasRequest.newBuilder() + .setAlias(request.getAlias()) + .build(); + io.milvus.grpc.Status status = blockingStub.dropAlias(dropAliasRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R alterAlias(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, AlterAliasReq request) { + String title = String.format("Alter alias %s for collection %s", request.getAlias(), request.getCollectionName()); + io.milvus.grpc.AlterAliasRequest alterAliasRequest = io.milvus.grpc.AlterAliasRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setAlias(request.getAlias()) + .build(); + io.milvus.grpc.Status status = blockingStub.alterAlias(alterAliasRequest); + rpcUtils.handleResponse(title, status); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R describeAlias(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String alias) { + String title = String.format("Describe alias %s", alias); + io.milvus.grpc.DescribeAliasRequest describeAliasRequest = io.milvus.grpc.DescribeAliasRequest.newBuilder() + .setAlias(alias) + .build(); + io.milvus.grpc.DescribeAliasResponse response = blockingStub.describeAlias(describeAliasRequest); + + rpcUtils.handleResponse(title, response.getStatus()); + + return R.success(DescribeAliasResp.builder() + .collectionName(response.getCollection()) + .alias(response.getAlias()) + .build()); + } + + public R listAliases(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub) { + String title = "List aliases"; + io.milvus.grpc.ListAliasesRequest listAliasesRequest = io.milvus.grpc.ListAliasesRequest.newBuilder() + .build(); + io.milvus.grpc.ListAliasesResponse response = blockingStub.listAliases(listAliasesRequest); + + rpcUtils.handleResponse(title, response.getStatus()); + + return R.success(ListAliasResp.builder() + .collectionName(response.getCollectionName()) + .alias(response.getAliasesList()) + .build()); + } +} diff --git a/src/main/java/io/milvus/v2/service/utility/request/AlterAliasReq.java b/src/main/java/io/milvus/v2/service/utility/request/AlterAliasReq.java new file mode 100644 index 000000000..a4dc28f86 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/utility/request/AlterAliasReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.utility.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class AlterAliasReq { + private String alias; + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/utility/request/CreateAliasReq.java b/src/main/java/io/milvus/v2/service/utility/request/CreateAliasReq.java new file mode 100644 index 000000000..2a8f97c16 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/utility/request/CreateAliasReq.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.utility.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class CreateAliasReq { + private String alias; + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/utility/request/DropAliasReq.java b/src/main/java/io/milvus/v2/service/utility/request/DropAliasReq.java new file mode 100644 index 000000000..f564b0272 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/utility/request/DropAliasReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.utility.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DropAliasReq { + private String alias; +} diff --git a/src/main/java/io/milvus/v2/service/utility/request/FlushReq.java b/src/main/java/io/milvus/v2/service/utility/request/FlushReq.java new file mode 100644 index 000000000..4bbdd0b41 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/utility/request/FlushReq.java @@ -0,0 +1,10 @@ +package io.milvus.v2.service.utility.request; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class FlushReq { + private String collectionName; +} diff --git a/src/main/java/io/milvus/v2/service/utility/response/DescribeAliasResp.java b/src/main/java/io/milvus/v2/service/utility/response/DescribeAliasResp.java new file mode 100644 index 000000000..e43b32a96 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/utility/response/DescribeAliasResp.java @@ -0,0 +1,11 @@ +package io.milvus.v2.service.utility.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +@Data +@SuperBuilder +public class DescribeAliasResp { + private String collectionName; + private String alias; +} diff --git a/src/main/java/io/milvus/v2/service/utility/response/ListAliasResp.java b/src/main/java/io/milvus/v2/service/utility/response/ListAliasResp.java new file mode 100644 index 000000000..df1baa486 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/utility/response/ListAliasResp.java @@ -0,0 +1,13 @@ +package io.milvus.v2.service.utility.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class ListAliasResp { + private String collectionName; + private List alias; +} diff --git a/src/main/java/io/milvus/v2/service/vector/VectorService.java b/src/main/java/io/milvus/v2/service/vector/VectorService.java new file mode 100644 index 000000000..110c8724b --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/VectorService.java @@ -0,0 +1,134 @@ +package io.milvus.v2.service.vector; + +import io.milvus.grpc.*; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.response.DescCollResponseWrapper; +import io.milvus.v2.service.BaseService; +import io.milvus.v2.service.collection.CollectionService; +import io.milvus.v2.service.collection.request.DescribeCollectionReq; +import io.milvus.v2.service.collection.response.DescribeCollectionResp; +import io.milvus.v2.service.index.request.DescribeIndexReq; +import io.milvus.v2.service.index.response.DescribeIndexResp; +import io.milvus.v2.service.index.IndexService; +import io.milvus.v2.service.vector.request.*; +import io.milvus.v2.service.vector.response.QueryResp; +import io.milvus.v2.service.vector.response.GetResp; +import io.milvus.v2.service.vector.response.SearchResp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class VectorService extends BaseService { + Logger logger = LoggerFactory.getLogger(VectorService.class); + public CollectionService collectionService = new CollectionService(); + public IndexService indexService = new IndexService(); + + public R insert(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, InsertReq request){ + String title = String.format("InsertRequest collectionName:%s", request.getCollectionName()); + checkCollectionExist(blockingStub, request.getCollectionName()); + DescribeCollectionRequest describeCollectionRequest = DescribeCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()).build(); + DescribeCollectionResponse descResp = blockingStub.describeCollection(describeCollectionRequest); + + MutationResult response = blockingStub.insert(dataUtils.convertGrpcInsertRequest(request, new DescCollResponseWrapper(descResp))); + rpcUtils.handleResponse(title, response.getStatus()); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R upsert(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, UpsertReq request) { + String title = String.format("UpsertRequest collectionName:%s", request.getCollectionName()); + + checkCollectionExist(milvusServiceBlockingStub, request.getCollectionName()); + + DescribeCollectionRequest describeCollectionRequest = DescribeCollectionRequest.newBuilder() + .setCollectionName(request.getCollectionName()).build(); + DescribeCollectionResponse descResp = milvusServiceBlockingStub.describeCollection(describeCollectionRequest); + + MutationResult response = milvusServiceBlockingStub.upsert(dataUtils.convertGrpcUpsertRequest(request, new DescCollResponseWrapper(descResp))); + rpcUtils.handleResponse(title, response.getStatus()); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R query(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, QueryReq request) { + String title = String.format("QueryRequest collectionName:%s", request.getCollectionName()); + checkCollectionExist(milvusServiceBlockingStub, request.getCollectionName()); + R descR = collectionService.describeCollection(milvusServiceBlockingStub, DescribeCollectionReq.builder().collectionName(request.getCollectionName()).build()); + if(request.getOutputFields() == null){ + request.setOutputFields(descR.getData().getFieldNames()); + } + QueryResults response = milvusServiceBlockingStub.query(vectorUtils.ConvertToGrpcQueryRequest(request)); + rpcUtils.handleResponse(title, response.getStatus()); + + QueryResp res = QueryResp.builder() + .queryResults(convertUtils.getEntities(response)) + .build(); + return R.success(res); + + } + + public R search(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, SearchReq request) { + String title = String.format("SearchRequest collectionName:%s", request.getCollectionName()); + + checkCollectionExist(milvusServiceBlockingStub, request.getCollectionName()); + R descR = collectionService.describeCollection(milvusServiceBlockingStub, DescribeCollectionReq.builder().collectionName(request.getCollectionName()).build()); + if (request.getVectorFieldName() == null) { + request.setVectorFieldName(descR.getData().getVectorFieldName().get(0)); + } + if(request.getOutFields() == null){ + request.setOutFields(descR.getData().getFieldNames()); + } + DescribeIndexReq describeIndexReq = DescribeIndexReq.builder() + .collectionName(request.getCollectionName()) + .fieldName(request.getVectorFieldName()) + .build(); + R respR = indexService.describeIndex(milvusServiceBlockingStub, describeIndexReq); + + SearchRequest searchRequest = vectorUtils.ConvertToGrpcSearchRequest(respR.getData().getMetricType(), request); + + SearchResults response = milvusServiceBlockingStub.search(searchRequest); + rpcUtils.handleResponse(title, response.getStatus()); + + SearchResp searchResp = SearchResp.builder() + .searchResults(convertUtils.getEntities(response)) + .build(); + return R.success(searchResp); + } + + public R delete(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, DeleteReq request) { + String title = String.format("DeleteRequest collectionName:%s", request.getCollectionName()); + checkCollectionExist(milvusServiceBlockingStub, request.getCollectionName()); + R respR = collectionService.describeCollection(milvusServiceBlockingStub, DescribeCollectionReq.builder().collectionName(request.getCollectionName()).build()); + if(request.getExpr() == null){ + request.setExpr(vectorUtils.getExprById(respR.getData().getPrimaryFieldName(), request.getIds())); + } + DeleteRequest deleteRequest = DeleteRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .setPartitionName(request.getPartitionName()) + .setExpr(request.getExpr()) + .build(); + MutationResult response = milvusServiceBlockingStub.delete(deleteRequest); + rpcUtils.handleResponse(title, response.getStatus()); + return R.success(new RpcStatus(RpcStatus.SUCCESS_MSG)); + } + + public R get(MilvusServiceGrpc.MilvusServiceBlockingStub milvusServiceBlockingStub, GetReq request) { + String title = String.format("GetRequest collectionName:%s", request.getCollectionName()); + checkCollectionExist(milvusServiceBlockingStub, request.getCollectionName()); + DescribeCollectionReq describeCollectionReq = DescribeCollectionReq.builder() + .collectionName(request.getCollectionName()) + .build(); + R resp = collectionService.describeCollection(milvusServiceBlockingStub, describeCollectionReq); + + String expr = vectorUtils.getExprById(resp.getData().getPrimaryFieldName(), request.getIds()); + QueryReq queryReq = QueryReq.builder() + .collectionName(request.getCollectionName()) + .expr(expr) + .build(); + R queryResp = query(milvusServiceBlockingStub, queryReq); + + GetResp getResp = GetResp.builder() + .getResults(queryResp.getData().getQueryResults()) + .build(); + return R.success(getResp); + } +} diff --git a/src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java b/src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java new file mode 100644 index 000000000..c898f8c50 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/request/DeleteReq.java @@ -0,0 +1,17 @@ +package io.milvus.v2.service.vector.request; + +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class DeleteReq { + private String collectionName; + @Builder.Default + private String partitionName = ""; + private String expr; + private List ids; +} diff --git a/src/main/java/io/milvus/v2/service/vector/request/GetReq.java b/src/main/java/io/milvus/v2/service/vector/request/GetReq.java new file mode 100644 index 000000000..882f8bbb5 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/request/GetReq.java @@ -0,0 +1,16 @@ +package io.milvus.v2.service.vector.request; + +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class GetReq { + private String collectionName; + @Builder.Default + private String partitionName = ""; + private List ids; +} diff --git a/src/main/java/io/milvus/v2/service/vector/request/InsertReq.java b/src/main/java/io/milvus/v2/service/vector/request/InsertReq.java new file mode 100644 index 000000000..c9d1eddc5 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/request/InsertReq.java @@ -0,0 +1,20 @@ +package io.milvus.v2.service.vector.request; + +import com.alibaba.fastjson.JSONObject; +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +@Data +@SuperBuilder +public class InsertReq { + //private List<> fields; + private List insertData; + private String collectionName; + @Builder.Default + private String partitionName = ""; +} diff --git a/src/main/java/io/milvus/v2/service/vector/request/QueryReq.java b/src/main/java/io/milvus/v2/service/vector/request/QueryReq.java new file mode 100644 index 000000000..80137fcec --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/request/QueryReq.java @@ -0,0 +1,27 @@ +package io.milvus.v2.service.vector.request; + +import io.milvus.v2.common.ConsistencyLevel; +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.ArrayList; +import java.util.List; + +@Data +@SuperBuilder +public class QueryReq { + private String collectionName; + @Builder.Default + private List partitionNames = new ArrayList<>(); + private List outputFields; + private String expr; + private long travelTimestamp; + private long guaranteeTimestamp; + private long gracefulTime; + @Builder.Default + private ConsistencyLevel consistencyLevel = ConsistencyLevel.BOUNDED; + private long offset; + private long limit; + private boolean ignoreGrowing; +} diff --git a/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java b/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java new file mode 100644 index 000000000..026c23988 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java @@ -0,0 +1,36 @@ +package io.milvus.v2.service.vector.request; + +import io.milvus.v2.common.ConsistencyLevel; +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.ArrayList; +import java.util.List; + +@Data +@SuperBuilder +public class SearchReq { + private String collectionName; + @Builder.Default + private List partitionNames = new ArrayList<>(); + private String vectorFieldName; + private int topK; + private String expr; + private List outFields; + private List vectors; + private long offset; + private long limit; + + //private final Long NQ; + @Builder.Default + private int roundDecimal = -1; + @Builder.Default + private String params = "{\"nprobe\": 10}"; + private long guaranteeTimestamp; + @Builder.Default + private Long gracefulTime = 5000L; + @Builder.Default + private ConsistencyLevel consistencyLevel = ConsistencyLevel.BOUNDED; + private boolean ignoreGrowing; +} diff --git a/src/main/java/io/milvus/v2/service/vector/request/UpsertReq.java b/src/main/java/io/milvus/v2/service/vector/request/UpsertReq.java new file mode 100644 index 000000000..a4e828b98 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/request/UpsertReq.java @@ -0,0 +1,23 @@ +package io.milvus.v2.service.vector.request; + +import com.alibaba.fastjson.JSONObject; +import lombok.Builder; +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +@Data +@SuperBuilder +public class UpsertReq { + private List> upsertData; + private String collectionName; + @Builder.Default + private String partitionName = ""; + + public List getUpsertData() { + return new ArrayList(); + } +} diff --git a/src/main/java/io/milvus/v2/service/vector/response/GetResp.java b/src/main/java/io/milvus/v2/service/vector/response/GetResp.java new file mode 100644 index 000000000..472d3c859 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/response/GetResp.java @@ -0,0 +1,12 @@ +package io.milvus.v2.service.vector.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@Data +@SuperBuilder +public class GetResp { + public List getResults; +} diff --git a/src/main/java/io/milvus/v2/service/vector/response/QueryResp.java b/src/main/java/io/milvus/v2/service/vector/response/QueryResp.java new file mode 100644 index 000000000..a65eb36c4 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/response/QueryResp.java @@ -0,0 +1,19 @@ +package io.milvus.v2.service.vector.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; +import java.util.Map; + +@Data +@SuperBuilder +public class QueryResp { + private List queryResults; + + @Data + @SuperBuilder + public static class QueryResult { + private Map fields; + } +} diff --git a/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java b/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java new file mode 100644 index 000000000..40d617ee1 --- /dev/null +++ b/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java @@ -0,0 +1,20 @@ +package io.milvus.v2.service.vector.response; + +import lombok.Data; +import lombok.experimental.SuperBuilder; + +import java.util.List; +import java.util.Map; + +@Data +@SuperBuilder +public class SearchResp { + private List searchResults; + + @Data + @SuperBuilder + public static class SearchResult { + private Map fields; + private Float score; + } +} diff --git a/src/main/java/io/milvus/v2/utils/ClientUtils.java b/src/main/java/io/milvus/v2/utils/ClientUtils.java new file mode 100644 index 000000000..900f94516 --- /dev/null +++ b/src/main/java/io/milvus/v2/utils/ClientUtils.java @@ -0,0 +1,112 @@ +package io.milvus.v2.utils; + +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; +import io.grpc.stub.MetadataUtils; +import io.milvus.grpc.ListDatabasesRequest; +import io.milvus.grpc.ListDatabasesResponse; +import io.milvus.grpc.MilvusServiceGrpc; +import io.milvus.v2.client.ConnectConfig; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.concurrent.TimeUnit; + +public class ClientUtils { + Logger logger = LoggerFactory.getLogger(ClientUtils.class); + RpcUtils rpcUtils = new RpcUtils(); + public ManagedChannel getChannel(ConnectConfig connectConfig){ + ManagedChannel channel = null; + + Metadata metadata = new Metadata(); + + metadata.put(Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER), Base64.getEncoder().encodeToString(connectConfig.getAuthorization().getBytes(StandardCharsets.UTF_8))); + if (StringUtils.isNotEmpty(connectConfig.getDatabaseName())) { + metadata.put(Metadata.Key.of("dbname", Metadata.ASCII_STRING_MARSHALLER), connectConfig.getDatabaseName()); + } + + try { + if (StringUtils.isNotEmpty(connectConfig.getServerPemPath())) { + // one-way tls + SslContext sslContext = GrpcSslContexts.forClient() + .trustManager(new File(connectConfig.getServerPemPath())) + .build(); + + NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectConfig.getHost(), connectConfig.getPort()) + .overrideAuthority(connectConfig.getServerName()) + .sslContext(sslContext) + .maxInboundMessageSize(Integer.MAX_VALUE) + .keepAliveTime(connectConfig.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS) + .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS) + .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) + .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) + .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + if(connectConfig.isSecure()){ + builder.useTransportSecurity(); + } + channel = builder.build(); + } else if (StringUtils.isNotEmpty(connectConfig.getClientPemPath()) + && StringUtils.isNotEmpty(connectConfig.getClientKeyPath()) + && StringUtils.isNotEmpty(connectConfig.getCaPemPath())) { + // tow-way tls + SslContext sslContext = GrpcSslContexts.forClient() + .trustManager(new File(connectConfig.getCaPemPath())) + .keyManager(new File(connectConfig.getClientPemPath()), new File(connectConfig.getClientKeyPath())) + .build(); + + NettyChannelBuilder builder = NettyChannelBuilder.forAddress(connectConfig.getHost(), connectConfig.getPort()) + .sslContext(sslContext) + .maxInboundMessageSize(Integer.MAX_VALUE) + .keepAliveTime(connectConfig.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS) + .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS) + .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) + .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) + .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + if(connectConfig.isSecure()){ + builder.useTransportSecurity(); + } + if (StringUtils.isNotEmpty(connectConfig.getServerName())) { + builder.overrideAuthority(connectConfig.getServerName()); + } + channel = builder.build(); + } else { + // no tls + ManagedChannelBuilder builder = ManagedChannelBuilder.forAddress(connectConfig.getHost(), connectConfig.getPort()) + .usePlaintext() + .maxInboundMessageSize(Integer.MAX_VALUE) + .keepAliveTime(connectConfig.getKeepAliveTimeMs(), TimeUnit.MILLISECONDS) + .keepAliveTimeout(connectConfig.getKeepAliveTimeoutMs(), TimeUnit.MILLISECONDS) + .keepAliveWithoutCalls(connectConfig.isKeepAliveWithoutCalls()) + .idleTimeout(connectConfig.getIdleTimeoutMs(), TimeUnit.MILLISECONDS) + .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata)); + if(connectConfig.isSecure()){ + builder.useTransportSecurity(); + } + channel = builder.build(); + } + } catch (IOException e) { + logger.error("Failed to open credentials file, error:{}\n", e.getMessage()); + } + assert channel != null; + return channel; + } + + public void checkDatabaseExist(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, String dbName) { + String title = String.format("Check database %s exist", dbName); + ListDatabasesRequest listDatabasesRequest = ListDatabasesRequest.newBuilder().build(); + ListDatabasesResponse response = blockingStub.listDatabases(listDatabasesRequest); + rpcUtils.handleResponse(title, response.getStatus()); + if (!response.getDbNamesList().contains(dbName)) { + throw new IllegalArgumentException("Database " + dbName + " not exist"); + } + } +} diff --git a/src/main/java/io/milvus/v2/utils/ConvertUtils.java b/src/main/java/io/milvus/v2/utils/ConvertUtils.java new file mode 100644 index 000000000..a5da5a109 --- /dev/null +++ b/src/main/java/io/milvus/v2/utils/ConvertUtils.java @@ -0,0 +1,67 @@ +package io.milvus.v2.utils; + +import io.milvus.grpc.*; +import io.milvus.param.R; +import io.milvus.response.QueryResultsWrapper; +import io.milvus.response.SearchResultsWrapper; +import io.milvus.v2.service.index.response.DescribeIndexResp; +import io.milvus.v2.service.vector.response.QueryResp; +import io.milvus.v2.service.vector.response.SearchResp; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class ConvertUtils { + public List getEntities(QueryResults response) { + QueryResultsWrapper queryResultsWrapper = new QueryResultsWrapper(response); + List entities = new ArrayList<>(); + + if(response.getFieldsDataList().stream().anyMatch(fieldData -> fieldData.getFieldName().equals("count(*)"))){ + Map countField = new HashMap<>(); + long numOfEntities = response.getFieldsDataList().stream().filter(fieldData -> fieldData.getFieldName().equals("count(*)")).map(FieldData::getScalars).collect(Collectors.toList()).get(0).getLongData().getData(0); + countField.put("count(*)", numOfEntities); + + QueryResp.QueryResult queryResult = QueryResp.QueryResult.builder() + .fields(countField) + .build(); + entities.add(queryResult); + + return entities; + } + queryResultsWrapper.getRowRecords().forEach(rowRecord -> { + QueryResp.QueryResult queryResult = QueryResp.QueryResult.builder() + .fields(rowRecord.getFieldValues()) + .build(); + entities.add(queryResult); + }); + return entities; + } + + public List getEntities(SearchResults response) { + SearchResultsWrapper searchResultsWrapper = new SearchResultsWrapper(response.getResults()); + + return searchResultsWrapper.getIDScore(0).stream().map(idScore -> SearchResp.SearchResult.builder() + .fields(idScore.getFieldValues()) + .score(idScore.getScore()) + .build()).collect(Collectors.toList()); + } + + public R convertToDescribeIndexResp(DescribeIndexResponse response) { + DescribeIndexResp describeIndexResp = DescribeIndexResp.builder() + .indexName(response.getIndexDescriptions(0).getIndexName()) + .fieldName(response.getIndexDescriptions(0).getFieldName()) + .build(); + List params = response.getIndexDescriptions(0).getParamsList(); + for(KeyValuePair param : params) { + if (param.getKey().equals("index_type")) { + describeIndexResp.setIndexType(param.getValue()); + }else if (param.getKey().equals("metric_type")) { + describeIndexResp.setMetricType(param.getValue()); + } + } + return R.success(describeIndexResp); + } +} diff --git a/src/main/java/io/milvus/v2/utils/DataUtils.java b/src/main/java/io/milvus/v2/utils/DataUtils.java new file mode 100644 index 000000000..1c6a7e688 --- /dev/null +++ b/src/main/java/io/milvus/v2/utils/DataUtils.java @@ -0,0 +1,469 @@ +package io.milvus.v2.utils; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.google.protobuf.ByteString; +import io.milvus.exception.IllegalResponseException; +import io.milvus.exception.ParamException; +import io.milvus.grpc.*; +import io.milvus.param.Constant; +import io.milvus.param.ParamUtils; +import io.milvus.param.collection.FieldType; +import io.milvus.param.dml.InsertParam; +import io.milvus.response.DescCollResponseWrapper; +import io.milvus.v2.service.vector.request.InsertReq; +import io.milvus.v2.service.vector.request.UpsertReq; +import lombok.NonNull; +import org.apache.commons.lang3.StringUtils; + +import java.nio.ByteBuffer; +import java.util.*; +import java.util.stream.Collectors; + +public class DataUtils { + private InsertRequest.Builder insertBuilder; + private UpsertRequest.Builder upsertBuilder; + private static final Set vectorDataType = new HashSet() {{ + add(DataType.FloatVector); + add(DataType.BinaryVector); + }}; + + public InsertRequest convertGrpcInsertRequest(@NonNull InsertReq requestParam, + DescCollResponseWrapper wrapper) { + String collectionName = requestParam.getCollectionName(); + + // generate insert request builder + MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build(); + insertBuilder = InsertRequest.newBuilder() + .setCollectionName(collectionName) + .setBase(msgBase) + .setNumRows(requestParam.getInsertData().size()); +// if (StringUtils.isNotEmpty(requestParam.getDatabaseName())) { +// insertBuilder.setDbName(requestParam.getDatabaseName()); +// } + fillFieldsData(requestParam, wrapper); + return insertBuilder.build(); + } + public UpsertRequest convertGrpcUpsertRequest(@NonNull UpsertReq requestParam, + DescCollResponseWrapper wrapper) { + String collectionName = requestParam.getCollectionName(); + + // currently, not allow to upsert for collection whose primary key is auto-generated + FieldType pk = wrapper.getPrimaryField(); + if (pk.isAutoID()) { + throw new ParamException(String.format("Upsert don't support autoID==True, collection: %s", + requestParam.getCollectionName())); + } + + // generate upsert request builder + MsgBase msgBase = MsgBase.newBuilder().setMsgType(MsgType.Insert).build(); + upsertBuilder = UpsertRequest.newBuilder() + .setCollectionName(collectionName) + .setBase(msgBase) + .setNumRows(requestParam.getUpsertData().size()); +// if (StringUtils.isNotEmpty(requestParam.getDatabaseName())) { +// upsertBuilder.setDbName(requestParam.getDatabaseName()); +// } + fillFieldsData(requestParam, wrapper); + return upsertBuilder.build(); + } + + private void addFieldsData(io.milvus.grpc.FieldData value) { + if (insertBuilder != null) { + insertBuilder.addFieldsData(value); + } else if (upsertBuilder != null) { + upsertBuilder.addFieldsData(value); + } + } + + private void setPartitionName(String value) { + if (insertBuilder != null) { + insertBuilder.setPartitionName(value); + } else if (upsertBuilder != null) { + upsertBuilder.setPartitionName(value); + } + } + + private void fillFieldsData(UpsertReq requestParam, DescCollResponseWrapper wrapper) { + // set partition name only when there is no partition key field + String partitionName = requestParam.getPartitionName(); + boolean isPartitionKeyEnabled = false; + for (FieldType fieldType : wrapper.getFields()) { + if (fieldType.isPartitionKey()) { + isPartitionKeyEnabled = true; + break; + } + } + if (isPartitionKeyEnabled) { + if (partitionName != null && !partitionName.isEmpty()) { + String msg = "Collection " + requestParam.getCollectionName() + " has partition key, not allow to specify partition name"; + throw new ParamException(msg); + } + } else if (partitionName != null) { + this.setPartitionName(partitionName); + } + + // convert insert data + List rowFields = requestParam.getUpsertData(); + + checkAndSetRowData(wrapper, rowFields); + + } + + private void fillFieldsData(InsertReq requestParam, DescCollResponseWrapper wrapper) { + // set partition name only when there is no partition key field + String partitionName = requestParam.getPartitionName(); + boolean isPartitionKeyEnabled = false; + for (FieldType fieldType : wrapper.getFields()) { + if (fieldType.isPartitionKey()) { + isPartitionKeyEnabled = true; + break; + } + } + if (isPartitionKeyEnabled) { + if (partitionName != null && !partitionName.isEmpty()) { + String msg = "Collection " + requestParam.getCollectionName() + " has partition key, not allow to specify partition name"; + throw new ParamException(msg); + } + } else if (partitionName != null) { + this.setPartitionName(partitionName); + } + + // convert insert data + List rowFields = requestParam.getInsertData(); + + checkAndSetRowData(wrapper, rowFields); + + } + + private void checkAndSetRowData(DescCollResponseWrapper wrapper, List rows) { + List fieldTypes = wrapper.getFields(); + + Map nameInsertInfo = new HashMap<>(); + ParamUtils.InsertDataInfo insertDynamicDataInfo = ParamUtils.InsertDataInfo.builder().fieldType( + FieldType.newBuilder() + .withName(Constant.DYNAMIC_FIELD_NAME) + .withDataType(DataType.JSON) + .withIsDynamic(true) + .build()) + .data(new LinkedList<>()).build(); + for (JSONObject row : rows) { + for (FieldType fieldType : fieldTypes) { + String fieldName = fieldType.getName(); + ParamUtils.InsertDataInfo insertDataInfo = nameInsertInfo.getOrDefault(fieldName, ParamUtils.InsertDataInfo.builder() + .fieldType(fieldType).data(new LinkedList<>()).build()); + + // check normalField + Object rowFieldData = row.get(fieldName); + if (rowFieldData != null) { + if (fieldType.isAutoID()) { + String msg = "The primary key: " + fieldName + " is auto generated, no need to input."; + throw new ParamException(msg); + } + checkFieldData(fieldType, Lists.newArrayList(rowFieldData), false); + + insertDataInfo.getData().add(rowFieldData); + nameInsertInfo.put(fieldName, insertDataInfo); + } else { + // check if autoId + if (!fieldType.isAutoID()) { + String msg = "The field: " + fieldType.getName() + " is not provided."; + throw new ParamException(msg); + } + } + } + + // deal with dynamicField + if (wrapper.getEnableDynamicField()) { + JSONObject dynamicField = new JSONObject(); + for (String rowFieldName : row.keySet()) { + if (!nameInsertInfo.containsKey(rowFieldName)) { + dynamicField.put(rowFieldName, row.get(rowFieldName)); + } + } + insertDynamicDataInfo.getData().add(dynamicField); + } + } + + for (String fieldNameKey : nameInsertInfo.keySet()) { + ParamUtils.InsertDataInfo insertDataInfo = nameInsertInfo.get(fieldNameKey); + this.addFieldsData(genFieldData(insertDataInfo.getFieldType(), insertDataInfo.getData())); + } + if (wrapper.getEnableDynamicField()) { + this.addFieldsData(genFieldData(insertDynamicDataInfo.getFieldType(), insertDynamicDataInfo.getData(), Boolean.TRUE)); + } + } + + public InsertRequest buildInsertRequest() { + if (insertBuilder != null) { + return insertBuilder.build(); + } + throw new ParamException("Unable to build insert request since no input"); + } + private static FieldData genFieldData(FieldType fieldType, List objects) { + return genFieldData(fieldType, objects, Boolean.FALSE); + } + + @SuppressWarnings("unchecked") + private static FieldData genFieldData(FieldType fieldType, List objects, boolean isDynamic) { + if (objects == null) { + throw new ParamException("Cannot generate FieldData from null object"); + } + DataType dataType = fieldType.getDataType(); + String fieldName = fieldType.getName(); + FieldData.Builder builder = FieldData.newBuilder(); + if (vectorDataType.contains(dataType)) { + VectorField vectorField = genVectorField(dataType, objects); + return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build(); + } else { + ScalarField scalarField = genScalarField(fieldType, objects); + if (isDynamic) { + return builder.setType(dataType).setScalars(scalarField).setIsDynamic(true).build(); + } + return builder.setFieldName(fieldName).setType(dataType).setScalars(scalarField).build(); + } + } + + @SuppressWarnings("unchecked") + private static VectorField genVectorField(DataType dataType, List objects) { + if (dataType == DataType.FloatVector) { + List floats = new ArrayList<>(); + // each object is List + for (Object object : objects) { + if (object instanceof List) { + List list = (List) object; + floats.addAll(list); + } else { + throw new ParamException("The type of FloatVector must be List"); + } + } + + int dim = floats.size() / objects.size(); + FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build(); + return VectorField.newBuilder().setDim(dim).setFloatVector(floatArray).build(); + } else if (dataType == DataType.BinaryVector) { + ByteBuffer totalBuf = null; + int dim = 0; + // each object is ByteBuffer + for (Object object : objects) { + ByteBuffer buf = (ByteBuffer) object; + if (totalBuf == null) { + totalBuf = ByteBuffer.allocate(buf.position() * objects.size()); + totalBuf.put(buf.array()); + dim = buf.position() * 8; + } else { + totalBuf.put(buf.array()); + } + } + + assert totalBuf != null; + ByteString byteString = ByteString.copyFrom(totalBuf.array()); + return VectorField.newBuilder().setDim(dim).setBinaryVector(byteString).build(); + } + + throw new ParamException("Illegal vector dataType:" + dataType); + } + + private static ScalarField genScalarField(FieldType fieldType, List objects) { + if (fieldType.getDataType() == DataType.Array) { + ArrayArray.Builder builder = ArrayArray.newBuilder(); + for (Object object : objects) { + List temp = (List)object; + ScalarField arrayField = genScalarField(fieldType.getElementType(), temp); + builder.addData(arrayField); + } + + return ScalarField.newBuilder().setArrayData(builder.build()).build(); + } else { + return genScalarField(fieldType.getDataType(), objects); + } + } + + private static ScalarField genScalarField(DataType dataType, List objects) { + switch (dataType) { + case None: + case UNRECOGNIZED: + throw new ParamException("Cannot support this dataType:" + dataType); + case Int64: { + List longs = objects.stream().map(p -> (Long) p).collect(Collectors.toList()); + LongArray longArray = LongArray.newBuilder().addAllData(longs).build(); + return ScalarField.newBuilder().setLongData(longArray).build(); + } + case Int32: + case Int16: + case Int8: { + List integers = objects.stream().map(p -> p instanceof Short ? ((Short) p).intValue() : (Integer) p).collect(Collectors.toList()); + IntArray intArray = IntArray.newBuilder().addAllData(integers).build(); + return ScalarField.newBuilder().setIntData(intArray).build(); + } + case Bool: { + List booleans = objects.stream().map(p -> (Boolean) p).collect(Collectors.toList()); + BoolArray boolArray = BoolArray.newBuilder().addAllData(booleans).build(); + return ScalarField.newBuilder().setBoolData(boolArray).build(); + } + case Float: { + List floats = objects.stream().map(p -> (Float) p).collect(Collectors.toList()); + FloatArray floatArray = FloatArray.newBuilder().addAllData(floats).build(); + return ScalarField.newBuilder().setFloatData(floatArray).build(); + } + case Double: { + List doubles = objects.stream().map(p -> (Double) p).collect(Collectors.toList()); + DoubleArray doubleArray = DoubleArray.newBuilder().addAllData(doubles).build(); + return ScalarField.newBuilder().setDoubleData(doubleArray).build(); + } + case String: + case VarChar: { + List strings = objects.stream().map(p -> (String) p).collect(Collectors.toList()); + StringArray stringArray = StringArray.newBuilder().addAllData(strings).build(); + return ScalarField.newBuilder().setStringData(stringArray).build(); + } + case JSON: { + List byteStrings = objects.stream().map(p -> ByteString.copyFromUtf8(((JSONObject) p).toJSONString())) + .collect(Collectors.toList()); + JSONArray jsonArray = JSONArray.newBuilder().addAllData(byteStrings).build(); + return ScalarField.newBuilder().setJsonData(jsonArray).build(); + } + default: + throw new ParamException("Illegal scalar dataType:" + dataType); + } + } + private static void checkFieldData(FieldType fieldSchema, InsertParam.Field fieldData) { + List values = fieldData.getValues(); + checkFieldData(fieldSchema, values, false); + } + + private static void checkFieldData(FieldType fieldSchema, List values, boolean verifyElementType) { + HashMap errMsgs = getTypeErrorMsg(); + DataType dataType = verifyElementType ? fieldSchema.getElementType() : fieldSchema.getDataType(); + + if (verifyElementType && values.size() > fieldSchema.getMaxCapacity()) { + throw new ParamException(String.format("Array field '%s' length: %d exceeds max capacity: %d", + fieldSchema.getName(), values.size(), fieldSchema.getMaxCapacity())); + } + + switch (dataType) { + case FloatVector: { + int dim = fieldSchema.getDimension(); + for (int i = 0; i < values.size(); ++i) { + // is List<> ? + Object value = values.get(i); + if (!(value instanceof List)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + // is List ? + List temp = (List)value; + for (Object v : temp) { + if (!(v instanceof Float)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + } + + // check dimension + if (temp.size() != dim) { + String msg = "Incorrect dimension for field '%s': the no.%d vector's dimension: %d is not equal to field's dimension: %d"; + throw new ParamException(String.format(msg, fieldSchema.getName(), i, temp.size(), dim)); + } + } + } + break; + case BinaryVector: { + int dim = fieldSchema.getDimension(); + for (int i = 0; i < values.size(); ++i) { + Object value = values.get(i); + // is ByteBuffer? + if (!(value instanceof ByteBuffer)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + + // check dimension + ByteBuffer v = (ByteBuffer)value; + if (v.position()*8 != dim) { + String msg = "Incorrect dimension for field '%s': the no.%d vector's dimension: %d is not equal to field's dimension: %d"; + throw new ParamException(String.format(msg, fieldSchema.getName(), i, v.position()*8, dim)); + } + } + } + break; + case Int64: + for (Object value : values) { + if (!(value instanceof Long)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + } + break; + case Int32: + case Int16: + case Int8: + for (Object value : values) { + if (!(value instanceof Short) && !(value instanceof Integer)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + } + break; + case Bool: + for (Object value : values) { + if (!(value instanceof Boolean)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + } + break; + case Float: + for (Object value : values) { + if (!(value instanceof Float)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + } + break; + case Double: + for (Object value : values) { + if (!(value instanceof Double)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + } + break; + case VarChar: + case String: + for (Object value : values) { + if (!(value instanceof String)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + } + break; + case JSON: + for (Object value : values) { + if (!(value instanceof JSONObject)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + } + break; + case Array: + for (Object value : values) { + if (!(value instanceof List)) { + throw new ParamException(String.format(errMsgs.get(dataType), fieldSchema.getName())); + } + + List temp = (List)value; + checkFieldData(fieldSchema, temp, true); + } + break; + default: + throw new IllegalResponseException("Unsupported data type returned by FieldData"); + } + } + public static HashMap getTypeErrorMsg() { + final HashMap typeErrMsg = new HashMap<>(); + typeErrMsg.put(DataType.None, "Type mismatch for field '%s': the field type is illegal"); + typeErrMsg.put(DataType.Bool, "Type mismatch for field '%s': Bool field value type must be Boolean"); + typeErrMsg.put(DataType.Int8, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be Short or Integer"); + typeErrMsg.put(DataType.Int16, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be Short or Integer"); + typeErrMsg.put(DataType.Int32, "Type mismatch for field '%s': Int32/Int16/Int8 field value type must be Short or Integer"); + typeErrMsg.put(DataType.Int64, "Type mismatch for field '%s': Int64 field value type must be Long"); + typeErrMsg.put(DataType.Float, "Type mismatch for field '%s': Float field value type must be Float"); + typeErrMsg.put(DataType.Double, "Type mismatch for field '%s': Double field value type must be Double"); + typeErrMsg.put(DataType.String, "Type mismatch for field '%s': String field value type must be String"); + typeErrMsg.put(DataType.VarChar, "Type mismatch for field '%s': VarChar field value type must be String"); + typeErrMsg.put(DataType.FloatVector, "Type mismatch for field '%s': Float vector field's value type must be List"); + typeErrMsg.put(DataType.BinaryVector, "Type mismatch for field '%s': Binary vector field's value type must be ByteBuffer"); + return typeErrMsg; + } +} diff --git a/src/main/java/io/milvus/v2/utils/RpcUtils.java b/src/main/java/io/milvus/v2/utils/RpcUtils.java new file mode 100644 index 000000000..baabee922 --- /dev/null +++ b/src/main/java/io/milvus/v2/utils/RpcUtils.java @@ -0,0 +1,39 @@ +package io.milvus.v2.utils; + +import io.milvus.exception.ServerException; +import io.milvus.grpc.ErrorCode; +import io.milvus.grpc.Status; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RpcUtils { + + protected static final Logger logger = LoggerFactory.getLogger(RpcUtils.class); + + public void handleResponse(String requestInfo, Status status) { + // the server made a change for error code: + // for 2.2.x, error code is status.getErrorCode() + // for 2.3.x, error code is status.getCode(), and the status.getErrorCode() + // is also assigned according to status.getCode() + // + // For error cases: + // if we use 2.3.4 sdk to interact with 2.3.x server, getCode() is non-zero, getErrorCode() is non-zero + // if we use 2.3.4 sdk to interact with 2.2.x server, getCode() is zero, getErrorCode() is non-zero + // if we use <=2.3.3 sdk to interact with 2.2.x/2.3.x server, getCode() is not available, getErrorCode() is non-zero + + if (status.getCode() != 0 || !status.getErrorCode().equals(ErrorCode.Success)) { + logger.error("{} failed, error code: {}, reason: {}", requestInfo, + status.getCode() > 0 ? status.getCode() : status.getErrorCode().getNumber(), + status.getReason()); + + // 2.3.4 sdk to interact with 2.2.x server, the getCode() is zero, here we reset its value to getErrorCode() + int code = status.getCode(); + if (code == 0) { + code = status.getErrorCode().getNumber(); + } + throw new ServerException(status.getReason(), code, status.getErrorCode()); + } + + logger.debug("{} successfully!", requestInfo); + } +} diff --git a/src/main/java/io/milvus/v2/utils/SchemaUtils.java b/src/main/java/io/milvus/v2/utils/SchemaUtils.java new file mode 100644 index 000000000..a0e117c24 --- /dev/null +++ b/src/main/java/io/milvus/v2/utils/SchemaUtils.java @@ -0,0 +1,56 @@ +package io.milvus.v2.utils; + +import io.milvus.grpc.CollectionSchema; +import io.milvus.grpc.DataType; +import io.milvus.grpc.FieldSchema; +import io.milvus.grpc.KeyValuePair; +import io.milvus.v2.service.collection.request.CreateCollectionWithSchemaReq; + +import java.util.ArrayList; +import java.util.List; + +public class SchemaUtils { + public static FieldSchema convertToGrpcFieldSchema(CreateCollectionWithSchemaReq.FieldSchema fieldSchema) { + FieldSchema schema = FieldSchema.newBuilder() + .setName(fieldSchema.getName()) + .setDataType(DataType.valueOf(fieldSchema.getDataType().name())) + .setIsPrimaryKey(fieldSchema.getIsPrimaryKey()) + .setAutoID(fieldSchema.getAutoID()) + .build(); + if(fieldSchema.getDimension() != null){ + schema = schema.toBuilder().addTypeParams(KeyValuePair.newBuilder().setKey("dim").setValue(String.valueOf(fieldSchema.getDimension())).build()).build(); + } + if(fieldSchema.getDataType() == io.milvus.v2.common.DataType.VarChar && fieldSchema.getMaxLength() != null){ + schema = schema.toBuilder().addTypeParams(KeyValuePair.newBuilder().setKey("max_length").setValue(String.valueOf(fieldSchema.getMaxLength())).build()).build(); + } + return schema; + } + + public static CreateCollectionWithSchemaReq.CollectionSchema convertFromGrpcCollectionSchema(CollectionSchema schema) { + CreateCollectionWithSchemaReq.CollectionSchema collectionSchema = CreateCollectionWithSchemaReq.CollectionSchema.builder() + .description(schema.getDescription()) + .enableDynamicField(schema.getEnableDynamicField()) + .build(); + List fieldSchemas = new ArrayList<>(); + for (FieldSchema fieldSchema : schema.getFieldsList()) { + fieldSchemas.add(convertFromGrpcFieldSchema(fieldSchema)); + } + collectionSchema.setFieldSchemaList(fieldSchemas); + return collectionSchema; + } + + private static CreateCollectionWithSchemaReq.FieldSchema convertFromGrpcFieldSchema(FieldSchema fieldSchema) { + CreateCollectionWithSchemaReq.FieldSchema schema = CreateCollectionWithSchemaReq.FieldSchema.builder() + .name(fieldSchema.getName()) + .dataType(io.milvus.v2.common.DataType.valueOf(fieldSchema.getDataType().name())) + .isPrimaryKey(fieldSchema.getIsPrimaryKey()) + .autoID(fieldSchema.getAutoID()) + .build(); + for (KeyValuePair keyValuePair : fieldSchema.getTypeParamsList()) { + if(keyValuePair.getKey().equals("dim")){ + schema.setDimension(Integer.parseInt(keyValuePair.getValue())); + } + } + return schema; + } +} diff --git a/src/main/java/io/milvus/v2/utils/VectorUtils.java b/src/main/java/io/milvus/v2/utils/VectorUtils.java new file mode 100644 index 000000000..a828f80e4 --- /dev/null +++ b/src/main/java/io/milvus/v2/utils/VectorUtils.java @@ -0,0 +1,222 @@ +package io.milvus.v2.utils; + +import com.google.protobuf.ByteString; +import io.milvus.common.clientenum.ConsistencyLevelEnum; +import io.milvus.common.utils.JacksonUtils; +import io.milvus.exception.ParamException; +import io.milvus.grpc.*; +import io.milvus.param.Constant; +import io.milvus.v2.service.vector.request.QueryReq; +import io.milvus.v2.service.vector.request.SearchReq; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Map; + +public class VectorUtils { + + public QueryRequest ConvertToGrpcQueryRequest(QueryReq request){ + long guaranteeTimestamp = getGuaranteeTimestamp(ConsistencyLevelEnum.valueOf(request.getConsistencyLevel().name()), + request.getGuaranteeTimestamp(), request.getGracefulTime()); + QueryRequest.Builder builder = QueryRequest.newBuilder() + .setCollectionName(request.getCollectionName()) + .addAllPartitionNames(request.getPartitionNames()) + .addAllOutputFields(request.getOutputFields()) + .setExpr(request.getExpr()) + .setTravelTimestamp(request.getTravelTimestamp()) + .setGuaranteeTimestamp(guaranteeTimestamp); + + // a new parameter from v2.2.9, if user didn't specify consistency level, set this parameter to true + if (request.getConsistencyLevel() == null) { + builder.setUseDefaultConsistency(true); + } else { + builder.setConsistencyLevelValue(request.getConsistencyLevel().getCode()); + } + + // set offset and limit value. + // directly pass the two values, the server will verify them. + long offset = request.getOffset(); + if (offset > 0) { + builder.addQueryParams(KeyValuePair.newBuilder() + .setKey(Constant.OFFSET) + .setValue(String.valueOf(offset)) + .build()); + } + + long limit = request.getLimit(); + if (limit > 0) { + builder.addQueryParams(KeyValuePair.newBuilder() + .setKey(Constant.LIMIT) + .setValue(String.valueOf(limit)) + .build()); + } + + // ignore growing + builder.addQueryParams(KeyValuePair.newBuilder() + .setKey(Constant.IGNORE_GROWING) + .setValue(String.valueOf(request.isIgnoreGrowing())) + .build()); + + return builder.build(); + + } + + private static long getGuaranteeTimestamp(ConsistencyLevelEnum consistencyLevel, + long guaranteeTimestamp, Long gracefulTime){ + if(consistencyLevel == null){ + return 1L; + } + switch (consistencyLevel){ + case STRONG: + guaranteeTimestamp = 0L; + break; + case BOUNDED: + guaranteeTimestamp = (new Date()).getTime() - gracefulTime; + break; + case EVENTUALLY: + guaranteeTimestamp = 1L; + break; + } + return guaranteeTimestamp; + } + + public SearchRequest ConvertToGrpcSearchRequest(String metricType, SearchReq request) { + SearchRequest.Builder builder = SearchRequest.newBuilder() + .setDbName("") + .setCollectionName(request.getCollectionName()); + if (!request.getPartitionNames().isEmpty()) { + request.getPartitionNames().forEach(builder::addPartitionNames); + } + + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.METRIC_TYPE) + .setValue(metricType) + .build()); + + + // prepare target vectors + // TODO: check target vector dimension(use DescribeCollection get schema to compare) + PlaceholderType plType = PlaceholderType.None; + List vectors = request.getVectors(); + List byteStrings = new ArrayList<>(); + for (Object vector : vectors) { + if (vector instanceof List) { + plType = PlaceholderType.FloatVector; + List list = (List) vector; + ByteBuffer buf = ByteBuffer.allocate(Float.BYTES * list.size()); + buf.order(ByteOrder.LITTLE_ENDIAN); + list.forEach(buf::putFloat); + + byte[] array = buf.array(); + ByteString bs = ByteString.copyFrom(array); + byteStrings.add(bs); + } else if (vector instanceof ByteBuffer) { + plType = PlaceholderType.BinaryVector; + ByteBuffer buf = (ByteBuffer) vector; + byte[] array = buf.array(); + ByteString bs = ByteString.copyFrom(array); + byteStrings.add(bs); + } else { + String msg = "Search target vector type is illegal(Only allow List or ByteBuffer)"; + throw new ParamException(msg); + } + } + + PlaceholderValue.Builder pldBuilder = PlaceholderValue.newBuilder() + .setTag(Constant.VECTOR_TAG) + .setType(plType); + byteStrings.forEach(pldBuilder::addValues); + + PlaceholderValue plv = pldBuilder.build(); + PlaceholderGroup placeholderGroup = PlaceholderGroup.newBuilder() + .addPlaceholders(plv) + .build(); + + ByteString byteStr = placeholderGroup.toByteString(); + builder.setPlaceholderGroup(byteStr); + //builder.setNq(request.getNQ()); + + // search parameters + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.VECTOR_FIELD) + .setValue(request.getVectorFieldName()) + .build()) + .addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.TOP_K) + .setValue(String.valueOf(request.getTopK())) + .build()) + .addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.ROUND_DECIMAL) + .setValue(String.valueOf(request.getRoundDecimal())) + .build()) + .addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.IGNORE_GROWING) + .setValue(String.valueOf(request.isIgnoreGrowing())) + .build()); + + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.OFFSET) + .setValue(String.valueOf(request.getOffset())) + .build()); + + if (null != request.getParams() && !request.getParams().isEmpty()) { + try { + builder.addSearchParams( + KeyValuePair.newBuilder() + .setKey(Constant.PARAMS) + .setValue(request.getParams()) + .build()); + } catch (IllegalArgumentException e) { + throw new ParamException(e.getMessage() + e.getCause().getMessage()); + } + } + + if (!request.getOutFields().isEmpty()) { + request.getOutFields().forEach(builder::addOutputFields); + } + + // always use expression since dsl is discarded + builder.setDslType(DslType.BoolExprV1); + if (request.getExpr() != null && !request.getExpr().isEmpty()) { + builder.setDsl(request.getExpr()); + } + + long guaranteeTimestamp = getGuaranteeTimestamp(ConsistencyLevelEnum.valueOf(request.getConsistencyLevel().name()), + request.getGuaranteeTimestamp(), request.getGracefulTime()); + //builder.setTravelTimestamp(request.getTravelTimestamp()); + builder.setGuaranteeTimestamp(guaranteeTimestamp); + + // a new parameter from v2.2.9, if user didn't specify consistency level, set this parameter to true + if (request.getConsistencyLevel() == null) { + builder.setUseDefaultConsistency(true); + } else { + builder.setConsistencyLevelValue(request.getConsistencyLevel().getCode()); + } + + return builder.build(); + } + + public String getExprById(String primaryFieldName, List ids) { + StringBuilder sb = new StringBuilder(); + sb.append(primaryFieldName).append(" in ["); + for (Object id : ids) { + if (id instanceof String) { + sb.append("\"").append(id.toString()).append("\","); + } else { + sb.append(id.toString()).append(","); + } + } + sb.deleteCharAt(sb.length() - 1); + sb.append("]"); + return sb.toString(); + } +} diff --git a/src/test/java/io/milvus/v2/BaseTest.java b/src/test/java/io/milvus/v2/BaseTest.java new file mode 100644 index 000000000..52b55eebc --- /dev/null +++ b/src/test/java/io/milvus/v2/BaseTest.java @@ -0,0 +1,137 @@ +package io.milvus.v2; + +import io.milvus.grpc.*; +import io.milvus.v2.client.MilvusClientV2; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +import java.util.Collections; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class BaseTest { + @InjectMocks + public MilvusClientV2 client_v2 = new MilvusClientV2(null);; + @Mock + protected MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub; + + @BeforeEach + public void setUp() { + client_v2.setBlockingStub(blockingStub); + + Status successStatus = Status.newBuilder().setCode(0).build(); + BoolResponse trueResponse = BoolResponse.newBuilder().setStatus(successStatus).setValue(Boolean.TRUE).build(); + + CollectionSchema collectionSchema = CollectionSchema.newBuilder() + .setDescription("test") + .addFields(FieldSchema.newBuilder() + .setName("id") + .setDataType(DataType.Int64) + .setIsPrimaryKey(Boolean.TRUE) + .setAutoID(Boolean.FALSE) + .build()) + .addFields(FieldSchema.newBuilder() + .setName("vector") + .setDataType(DataType.FloatVector) + .addTypeParams(KeyValuePair.newBuilder().setKey("dim").setValue("2").build()) + .setIsPrimaryKey(Boolean.FALSE) + .setAutoID(Boolean.FALSE) + .build()) + .setEnableDynamicField(Boolean.FALSE) + .build(); + DescribeCollectionResponse describeCollectionResponse = DescribeCollectionResponse.newBuilder() + .setStatus(successStatus) + .setCollectionName("test") + .setSchema(collectionSchema) + .setNumPartitions(1) + .setCreatedUtcTimestamp(0) + .build(); + + IndexDescription index = IndexDescription.newBuilder() + .setIndexName("test") + .setFieldName("vector") + .addParams(KeyValuePair.newBuilder() + .setKey("index_type") + .setValue("IVF_FLAT") + .build()) + .addParams(KeyValuePair.newBuilder() + .setKey("metric_type") + .setValue("L2") + .build()) + .build(); + DescribeIndexResponse describeIndexResponse = DescribeIndexResponse.newBuilder() + .setStatus(successStatus) + .addIndexDescriptions(index) + .build(); + when(blockingStub.listDatabases(any())).thenReturn(ListDatabasesResponse.newBuilder().setStatus(successStatus).addDbNames("default").build()); + // collection api + when(blockingStub.showCollections(any(ShowCollectionsRequest.class))).thenReturn(ShowCollectionsResponse.newBuilder().setStatus(successStatus).addAllCollectionNames(Collections.singletonList("test")).build()); + when(blockingStub.createCollection(any(CreateCollectionRequest.class))).thenReturn(successStatus); + when(blockingStub.loadCollection(any())).thenReturn(successStatus); + when(blockingStub.releaseCollection(any())).thenReturn(successStatus); + when(blockingStub.dropCollection(any())).thenReturn(successStatus); + when(blockingStub.hasCollection(any())).thenReturn(trueResponse); + when(blockingStub.describeCollection(any())).thenReturn(describeCollectionResponse); + when(blockingStub.renameCollection(any())).thenReturn(successStatus); + when(blockingStub.getLoadState(any())).thenReturn(GetLoadStateResponse.newBuilder().setStatus(successStatus).build()); + + // index api + when(blockingStub.createIndex(any())).thenReturn(successStatus); + when(blockingStub.describeIndex(any())).thenReturn(describeIndexResponse); + when(blockingStub.dropIndex(any())).thenReturn(successStatus); + + //vector api + when(blockingStub.insert(any())).thenReturn(MutationResult.newBuilder().build()); + when(blockingStub.upsert(any())).thenReturn(MutationResult.newBuilder().build()); + when(blockingStub.query(any())).thenReturn(QueryResults.newBuilder().build()); + when(blockingStub.delete(any())).thenReturn(MutationResult.newBuilder().build()); + SearchResults searchResults = SearchResults.newBuilder() + .setResults(SearchResultData.newBuilder().addScores(1L).addTopks(0L).build()) + .build(); + when(blockingStub.search(any())).thenReturn(searchResults); + + // partition api + when(blockingStub.createPartition(any())).thenReturn(successStatus); + when(blockingStub.dropPartition(any())).thenReturn(successStatus); + when(blockingStub.hasPartition(any())).thenReturn(trueResponse); + when(blockingStub.showPartitions(any())).thenReturn(ShowPartitionsResponse.newBuilder().setStatus(successStatus).addPartitionNames("test").build()); + when(blockingStub.loadPartitions(any())).thenReturn(successStatus); + when(blockingStub.releasePartitions(any())).thenReturn(successStatus); + + // role api + when(blockingStub.createRole(any())).thenReturn(successStatus); + when(blockingStub.dropRole(any())).thenReturn(successStatus); + when(blockingStub.selectRole(any())).thenReturn(SelectRoleResponse.newBuilder().setStatus(successStatus).addResults(RoleResult.newBuilder().setRole(RoleEntity.newBuilder().setName("role_test").build()).build()).build()); + when(blockingStub.selectGrant(any())).thenReturn(SelectGrantResponse.newBuilder().setStatus(successStatus).addEntities(GrantEntity.newBuilder().setDbName("test").setObjectName("test").setObject(ObjectEntity.newBuilder().setName("test").build()).build()).build()); + + when(blockingStub.operatePrivilege(any())).thenReturn(successStatus); + when(blockingStub.operateUserRole(any())).thenReturn(successStatus); + + // user api + when(blockingStub.listCredUsers(any())).thenReturn(ListCredUsersResponse.newBuilder().addUsernames("user_test").build()); + when(blockingStub.createCredential(any())).thenReturn(successStatus); + when(blockingStub.updateCredential(any())).thenReturn(successStatus); + when(blockingStub.deleteCredential(any())).thenReturn(successStatus); + when(blockingStub.selectUser(any())).thenReturn(SelectUserResponse.newBuilder().setStatus(successStatus).addResults(UserResult.newBuilder().setUser(UserEntity.newBuilder().setName("user_test").build()).build()).build()); + + // utility api + when(blockingStub.flush(any())).thenReturn(FlushResponse.newBuilder().setStatus(successStatus).build()); + when(blockingStub.createAlias(any())).thenReturn(successStatus); + when(blockingStub.dropAlias(any())).thenReturn(successStatus); + when(blockingStub.alterAlias(any())).thenReturn(successStatus); + when(blockingStub.describeAlias(any())).thenReturn(DescribeAliasResponse.newBuilder().setStatus(successStatus).build()); + } + @AfterEach + public void tearDown() throws InterruptedException { + client_v2.close(3); + } +} diff --git a/src/test/java/io/milvus/v2/client/MilvusClientV2Test.java b/src/test/java/io/milvus/v2/client/MilvusClientV2Test.java new file mode 100644 index 000000000..2b114d840 --- /dev/null +++ b/src/test/java/io/milvus/v2/client/MilvusClientV2Test.java @@ -0,0 +1,22 @@ +package io.milvus.v2.client; + +import io.milvus.v2.BaseTest; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class MilvusClientV2Test extends BaseTest { + + @Test + void testMilvusClientV2() { + } + @Test + void testUseDatabase() { + try { + client_v2.useDatabase("test"); + }catch (Exception e) { + Assertions.assertEquals("Database test not exist", e.getMessage()); + } + + } + +} diff --git a/src/test/java/io/milvus/v2/service/collection/CollectionTest.java b/src/test/java/io/milvus/v2/service/collection/CollectionTest.java new file mode 100644 index 000000000..984902a2b --- /dev/null +++ b/src/test/java/io/milvus/v2/service/collection/CollectionTest.java @@ -0,0 +1,162 @@ +package io.milvus.v2.service.collection; + +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.common.DataType; +import io.milvus.v2.BaseTest; +import io.milvus.v2.service.collection.request.*; +import io.milvus.v2.service.collection.response.DescribeCollectionResp; +import io.milvus.v2.service.collection.response.ListCollectionsResp; +import io.milvus.v2.common.IndexParam; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.mockito.ArgumentMatchers.any; + +class CollectionTest extends BaseTest { + Logger logger = LoggerFactory.getLogger(CollectionTest.class); + + @Test + void testListCollections() { + R a = client_v2.listCollections(); + + logger.info("resp: {}", a.getData()); + Assertions.assertEquals(R.Status.Success.getCode(), a.getStatus()); + Assertions.assertEquals("test", a.getData().getCollectionNames().get(0)); + } + + @Test + void testCreateCollection() { + CreateCollectionReq req = CreateCollectionReq.builder() + .collectionName("test2") + .dimension(2) + .build(); + R resp = client_v2.createCollection(req); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + + @Test + void testCreateCollectionWithSchema() { + List fields = new ArrayList<>(); + CreateCollectionWithSchemaReq.FieldSchema idSchema = CreateCollectionWithSchemaReq.FieldSchema.builder() + .name("id") + .dataType(DataType.Int64) + .isPrimaryKey(Boolean.TRUE) + .autoID(Boolean.FALSE) + .build(); + CreateCollectionWithSchemaReq.FieldSchema metaSchema = CreateCollectionWithSchemaReq.FieldSchema.builder() + .name("meta") + .dataType(DataType.VarChar) + .build(); + CreateCollectionWithSchemaReq.FieldSchema vectorSchema = CreateCollectionWithSchemaReq.FieldSchema.builder() + .name("vector") + .dataType(DataType.FloatVector) + .dimension(2) + .build(); + + fields.add(idSchema); + fields.add(vectorSchema); + fields.add(metaSchema); + + CreateCollectionWithSchemaReq.CollectionSchema collectionSchema = CreateCollectionWithSchemaReq.CollectionSchema.builder() + .fieldSchemaList(fields) + .enableDynamicField(Boolean.TRUE) + .build(); + + IndexParam indexParam = IndexParam.builder() + .fieldName("vector") + .metricType(IndexParam.MetricType.L2) + .indexType(IndexParam.IndexType.AUTOINDEX) + .build(); + + CreateCollectionWithSchemaReq request = CreateCollectionWithSchemaReq.builder() + .collectionName("test") + .collectionSchema(collectionSchema) + .indexParams(Collections.singletonList(indexParam)) + .build(); + R resp = client_v2.createCollectionWithSchema(request); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + + @Test + void testDropCollection() { + DropCollectionReq req = DropCollectionReq.builder() + .collectionName("test") + .build(); + R resp = client_v2.dropCollection(req); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + + @Test + void testHasCollection() { + HasCollectionReq req = HasCollectionReq.builder() + .collectionName("test") + .build(); + R resp = client_v2.hasCollection(req); + logger.info("resp: {}", resp.getData()); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + @Test + void testDescribeCollection() { + DescribeCollectionReq req = DescribeCollectionReq.builder() + .collectionName("test2") + .build(); + R resp = client_v2.describeCollection(req); + logger.info("resp: {}", resp); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + + @Test + void testRenameCollection() { + RenameCollectionReq req = RenameCollectionReq.builder() + .collectionName("test2") + .newCollectionName("test") + .build(); + R resp = client_v2.renameCollection(req); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + + @Test + void testLoadCollection() { + LoadCollectionReq req = LoadCollectionReq.builder() + .collectionName("test") + .build(); + R resp = client_v2.loadCollection(req); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + + @Test + void testReleaseCollection() { + ReleaseCollectionReq req = ReleaseCollectionReq.builder() + .collectionName("test") + .build(); + R resp = client_v2.releaseCollection(req); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + + @Test + void testGetLoadState() { + GetLoadStateReq req = GetLoadStateReq.builder() + .collectionName("test") + .build(); + R resp = client_v2.getLoadState(req); + logger.info("resp: {}", resp.getData()); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } + +// @Test +// void testGetCollectionStats() { +// GetCollectionStatsReq req = GetCollectionStatsReq.builder() +// .collectionName("test") +// .build(); +// R resp = clientv_2.getCollectionStats(req); +// logger.info("resp: {}", resp); +// Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); +// } +} \ No newline at end of file diff --git a/src/test/java/io/milvus/v2/service/index/IndexTest.java b/src/test/java/io/milvus/v2/service/index/IndexTest.java new file mode 100644 index 000000000..e554ac5be --- /dev/null +++ b/src/test/java/io/milvus/v2/service/index/IndexTest.java @@ -0,0 +1,51 @@ +package io.milvus.v2.service.index; + +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.common.IndexParam; +import io.milvus.v2.BaseTest; +import io.milvus.v2.service.index.request.CreateIndexReq; +import io.milvus.v2.service.index.request.DescribeIndexReq; +import io.milvus.v2.service.index.request.DropIndexReq; +import io.milvus.v2.service.index.response.DescribeIndexResp; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class IndexTest extends BaseTest { + Logger logger = LoggerFactory.getLogger(IndexTest.class); + + @Test + void testCreateIndex() { + IndexParam indexParam = IndexParam.builder() + .metricType(IndexParam.MetricType.COSINE) + .indexType(IndexParam.IndexType.AUTOINDEX) + .fieldName("vector") + .build(); + + CreateIndexReq createIndexReq = CreateIndexReq.builder() + .collectionName("test") + .indexParam(indexParam) + .build(); + client_v2.createIndex(createIndexReq); + } + @Test + void testDescribeIndex() { + DescribeIndexReq describeIndexReq = DescribeIndexReq.builder() + .collectionName("test") + .fieldName("vector") + .build(); + R responseR = client_v2.describeIndex(describeIndexReq); + logger.info(responseR.toString()); + } + @Test + void testDropIndex() { + DropIndexReq dropIndexReq = DropIndexReq.builder() + .collectionName("test") + .fieldName("vector") + .build(); + R resp = client_v2.dropIndex(dropIndexReq); + Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus()); + } +} \ No newline at end of file diff --git a/src/test/java/io/milvus/v2/service/partition/PartitionTest.java b/src/test/java/io/milvus/v2/service/partition/PartitionTest.java new file mode 100644 index 000000000..c12653be7 --- /dev/null +++ b/src/test/java/io/milvus/v2/service/partition/PartitionTest.java @@ -0,0 +1,87 @@ +package io.milvus.v2.service.partition; + +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.BaseTest; +import io.milvus.v2.service.partition.request.*; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + +class PartitionTest extends BaseTest { + Logger logger = LoggerFactory.getLogger(PartitionTest.class); + + @Test + void testCreatePartition() { + CreatePartitionReq req = CreatePartitionReq.builder() + .collectionName("test") + .partitionName("test") + .build(); + R res = client_v2.createPartition(req); + logger.info("resp: {}", res); + Assertions.assertEquals(0, res.getStatus()); + } + + @Test + void testDropPartition() { + DropPartitionReq req = DropPartitionReq.builder() + .collectionName("test") + .partitionName("test") + .build(); + R res = client_v2.dropPartition(req); + logger.info("resp: {}", res); + Assertions.assertEquals(0, res.getStatus()); + } + + @Test + void testHasPartition() { + HasPartitionReq req = HasPartitionReq.builder() + .collectionName("test") + .partitionName("_default") + .build(); + R res = client_v2.hasPartition(req); + logger.info("resp: {}", res); + Assertions.assertEquals(0, res.getStatus()); + } + + @Test + void testListPartitions() { + ListPartitionsReq req = ListPartitionsReq.builder() + .collectionName("test") + .build(); + R> res = client_v2.listPartitions(req); + logger.info("resp: {}", res); + Assertions.assertEquals(0, res.getStatus()); + } + + @Test + void testLoadPartition() { + List partitionNames = new ArrayList<>(); + partitionNames.add("test"); + LoadPartitionsReq req = LoadPartitionsReq.builder() + .collectionName("test") + .partitionNames(partitionNames) + .build(); + R res = client_v2.loadPartitions(req); + logger.info("resp: {}", res); + Assertions.assertEquals(0, res.getStatus()); + } + + @Test + void testReleasePartition() { + List partitionNames = new ArrayList<>(); + partitionNames.add("test"); + + ReleasePartitionsReq req = ReleasePartitionsReq.builder() + .collectionName("test") + .partitionNames(partitionNames) + .build(); + R res = client_v2.releasePartitions(req); + logger.info("resp: {}", res); + Assertions.assertEquals(0, res.getStatus()); + } +} \ No newline at end of file diff --git a/src/test/java/io/milvus/v2/service/rbac/RoleTest.java b/src/test/java/io/milvus/v2/service/rbac/RoleTest.java new file mode 100644 index 000000000..ba6a90d2b --- /dev/null +++ b/src/test/java/io/milvus/v2/service/rbac/RoleTest.java @@ -0,0 +1,103 @@ +package io.milvus.v2.service.rbac; + +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.BaseTest; +import io.milvus.v2.service.rbac.request.*; +import io.milvus.v2.service.rbac.response.DescribeRoleResp; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +class RoleTest extends BaseTest { + + Logger logger = LoggerFactory.getLogger(RoleTest.class); + + @Test + void testListRoles() { + R> roles = client_v2.listRoles(); + logger.info(roles.toString()); + Assertions.assertEquals(roles.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testCreateRole() { + CreateRoleReq request = CreateRoleReq.builder() + .roleName("test") + .build(); + R statusR = client_v2.createRole(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testDescribeRole() { + DescribeRoleReq describeRoleReq = DescribeRoleReq.builder() + .roleName("db_rw") + .build(); + R statusR = client_v2.describeRole(describeRoleReq); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testDropRole() { + DropRoleReq request = DropRoleReq.builder() + .roleName("test") + .build(); + R statusR = client_v2.dropRole(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testGrantPrivilege() { + GrantPrivilegeReq request = GrantPrivilegeReq.builder() + .roleName("db_rw") + .objectName("") + .objectType("") + .privilege("") + .build(); + R statusR = client_v2.grantPrivilege(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testRevokePrivilege() { + RevokePrivilegeReq request = RevokePrivilegeReq.builder() + .roleName("db_rw") + .objectName("") + .objectType("") + .privilege("") + .build(); + R statusR = client_v2.revokePrivilege(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testGrantRole() { + GrantRoleReq request = GrantRoleReq.builder() + .roleName("db_ro") + .userName("test") + .build(); + R statusR = client_v2.grantRole(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testRevokeRole() { + RevokeRoleReq request = RevokeRoleReq.builder() + .roleName("db_ro") + .userName("test") + .build(); + R statusR = client_v2.revokeRole(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } +} \ No newline at end of file diff --git a/src/test/java/io/milvus/v2/service/rbac/UserTest.java b/src/test/java/io/milvus/v2/service/rbac/UserTest.java new file mode 100644 index 000000000..4178a9690 --- /dev/null +++ b/src/test/java/io/milvus/v2/service/rbac/UserTest.java @@ -0,0 +1,67 @@ +package io.milvus.v2.service.rbac; + +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.BaseTest; +import io.milvus.v2.service.rbac.request.*; +import io.milvus.v2.service.rbac.response.DescribeUserResp; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +class UserTest extends BaseTest { + Logger logger = LoggerFactory.getLogger(UserTest.class); + + @Test + void listUsers() { + R> resp = client_v2.listUsers(); + logger.info("resp: {}", resp); + Assertions.assertEquals(resp.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testDescribeUser() { + DescribeUserReq req = DescribeUserReq.builder() + .userName("test") + .build(); + R resp = client_v2.describeUser(req); + logger.info("resp: {}", resp); + Assertions.assertEquals(resp.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testCreateUser() { + CreateUserReq req = CreateUserReq.builder() + .userName("test") + .password("Zilliz@2023") + .build(); + R resp = client_v2.createUser(req); + logger.info("resp: {}", resp); + Assertions.assertEquals(resp.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testUpdatePassword() { + UpdatePasswordReq req = UpdatePasswordReq.builder() + .userName("test") + .password("Zilliz@2023") + .newPassword("Zilliz@2024") + .build(); + R resp = client_v2.updatePassword(req); + logger.info("resp: {}", resp); + Assertions.assertEquals(resp.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testDropUser() { + DropUserReq req = DropUserReq.builder() + .userName("test") + .build(); + R resp = client_v2.dropUser(req); + logger.info("resp: {}", resp); + Assertions.assertEquals(resp.getStatus(), R.Status.Success.getCode()); + } +} \ No newline at end of file diff --git a/src/test/java/io/milvus/v2/service/utility/UtilityTest.java b/src/test/java/io/milvus/v2/service/utility/UtilityTest.java new file mode 100644 index 000000000..d5bf0f3ae --- /dev/null +++ b/src/test/java/io/milvus/v2/service/utility/UtilityTest.java @@ -0,0 +1,70 @@ +package io.milvus.v2.service.utility; + +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.service.utility.request.FlushReq; +import io.milvus.v2.BaseTest; +import io.milvus.v2.service.utility.request.AlterAliasReq; +import io.milvus.v2.service.utility.request.CreateAliasReq; +import io.milvus.v2.service.utility.request.DropAliasReq; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.junit.jupiter.api.Assertions.*; + +class UtilityTest extends BaseTest { + Logger logger = LoggerFactory.getLogger(UtilityTest.class); + + @Test + void testFlush() { + FlushReq req = FlushReq.builder() + .collectionName("test") + .build(); + R statusR = client_v2.flush(req); + logger.info("resp: {}", statusR.getData()); + assertEquals(R.Status.Success.getCode(), statusR.getStatus()); + } + + @Test + void testCreateAlias() { + CreateAliasReq req = CreateAliasReq.builder() + .collectionName("test") + .alias("test_alias") + .build(); + R statusR = client_v2.createAlias(req); + logger.info("resp: {}", statusR.getData()); + assertEquals(R.Status.Success.getCode(), statusR.getStatus()); + } + @Test + void testDropAlias() { + DropAliasReq req = DropAliasReq.builder() + .alias("test_alias") + .build(); + R statusR = client_v2.dropAlias(req); + logger.info("resp: {}", statusR.getData()); + assertEquals(R.Status.Success.getCode(), statusR.getStatus()); + } + @Test + void testAlterAlias() { + AlterAliasReq req = AlterAliasReq.builder() + .collectionName("test") + .alias("test_alias") + .build(); + R statusR = client_v2.alterAlias(req); + logger.info("resp: {}", statusR.getData()); + assertEquals(R.Status.Success.getCode(), statusR.getStatus()); + } +// @Test +// void describeAlias() { +// R statusR = clientv_2.describeAlias("test_alias"); +// logger.info("resp: {}", statusR.getData()); +// assertEquals(R.Status.Success.getCode(), statusR.getStatus()); +// } +// @Test +// void listAliases() { +// R statusR = clientv_2.listAliases(); +// logger.info("resp: {}", statusR.getData()); +// assertEquals(R.Status.Success.getCode(), statusR.getStatus()); +// } +} \ No newline at end of file diff --git a/src/test/java/io/milvus/v2/service/vector/VectorTest.java b/src/test/java/io/milvus/v2/service/vector/VectorTest.java new file mode 100644 index 000000000..b9063f262 --- /dev/null +++ b/src/test/java/io/milvus/v2/service/vector/VectorTest.java @@ -0,0 +1,120 @@ +package io.milvus.v2.service.vector; + +import com.alibaba.fastjson.JSONObject; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.v2.BaseTest; +import io.milvus.v2.service.vector.request.*; +import io.milvus.v2.service.vector.response.GetResp; +import io.milvus.v2.service.vector.response.QueryResp; +import io.milvus.v2.service.vector.response.SearchResp; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; + +class VectorTest extends BaseTest { + + Logger logger = LoggerFactory.getLogger(VectorTest.class); + + @Test + void testInsert() { + JSONObject vector = new JSONObject(); + List vectorList = new ArrayList<>(); + vectorList.add(1.0f); + vectorList.add(2.0f); + vector.put("vector", vectorList); + vector.put("id", 0L); + + InsertReq request = InsertReq.builder() + .collectionName("test2") + .insertData(Collections.singletonList(vector)) + .build(); + R statusR = client_v2.insert(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testUpsert() { + + JSONObject jsonObject = new JSONObject(); + List vectorList = new ArrayList<>(); + vectorList.add(2.0f); + vectorList.add(3.0f); + jsonObject.put("vector", vectorList); + //jsonObject.put("id", 0L); + UpsertReq request = UpsertReq.builder() + .collectionName("test") + .upsertData(Collections.singletonList(jsonObject)) + .build(); + + R statusR = client_v2.upsert(request); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testQuery() { + QueryReq req = QueryReq.builder() + .collectionName("test2") + .expr("") + .limit(10) + //.outputFields(Collections.singletonList("count(*)")) + .build(); + R resultsR = client_v2.query(req); + + logger.info(resultsR.toString()); + Assertions.assertEquals(resultsR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testSearch() { + List vectorList = new ArrayList<>(); + vectorList.add(1.0f); + vectorList.add(2.0f); + SearchReq request = SearchReq.builder() + .collectionName("test2") + .vectors(Collections.singletonList(vectorList)) + .topK(10) + .offset(0L) + .build(); + R statusR = client_v2.search(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testDelete() { + DeleteReq request = DeleteReq.builder() + .collectionName("test") + .expr("id > 0") + .build(); + R statusR = client_v2.delete(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testDeleteById(){ + DeleteReq request = DeleteReq.builder() + .collectionName("test") + .ids(Collections.singletonList("0")) + .build(); + R statusR = client_v2.delete(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } + + @Test + void testGet() { + GetReq request = GetReq.builder() + .collectionName("test2") + .ids(Collections.singletonList("447198483337881033")) + .build(); + R statusR = client_v2.get(request); + logger.info(statusR.toString()); + Assertions.assertEquals(statusR.getStatus(), R.Status.Success.getCode()); + } +} \ No newline at end of file