From 235f170d8dffeb76fea9cc2416a4c085ab5adcac Mon Sep 17 00:00:00 2001 From: groot Date: Tue, 26 Nov 2024 16:10:36 +0800 Subject: [PATCH] BulkWriter supports Json/CSV (#1193) Signed-off-by: yhmo --- .../main/java/io/milvus/v2/SimpleExample.java | 58 ++--- .../java/io/milvus/bulkwriter/Buffer.java | 110 +++++++- .../io/milvus/bulkwriter/LocalBulkWriter.java | 28 ++- .../bulkwriter/LocalBulkWriterParam.java | 11 + .../milvus/bulkwriter/RemoteBulkWriter.java | 8 +- .../bulkwriter/RemoteBulkWriterParam.java | 11 + .../common/clientenum/BulkFileType.java | 2 + src/test/java/io/milvus/TestUtils.java | 162 ++++++++++++ .../io/milvus/bulkwriter/BulkWriterTest.java | 143 ++++++++++- .../milvus/client/MilvusClientDockerTest.java | 234 ++++-------------- .../client/MilvusMultiClientDockerTest.java | 64 +---- .../v2/client/MilvusClientV2DockerTest.java | 222 ++++------------- 12 files changed, 582 insertions(+), 471 deletions(-) create mode 100644 src/test/java/io/milvus/TestUtils.java diff --git a/examples/main/java/io/milvus/v2/SimpleExample.java b/examples/main/java/io/milvus/v2/SimpleExample.java index e52eadd65..07c9b29a6 100644 --- a/examples/main/java/io/milvus/v2/SimpleExample.java +++ b/examples/main/java/io/milvus/v2/SimpleExample.java @@ -103,35 +103,37 @@ public static void main(String[] args) { System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString()); } } + // search with template expression - // Map> expressionTemplateValues = new HashMap<>(); - // Map params = new HashMap<>(); - // params.put("max", 10); - // expressionTemplateValues.put("id < {max}", params); - // - // List list = Arrays.asList(1, 2, 3); - // Map params2 = new HashMap<>(); - // params2.put("list", list); - // expressionTemplateValues.put("id in {list}", params2); - // - // expressionTemplateValues.forEach((key, value) -> { - // SearchReq request = SearchReq.builder() - // .collectionName(collectionName) - // .data(Collections.singletonList(new FloatVec(new float[]{1.0f, 1.0f, 1.0f, 1.0f}))) - // .topK(10) - // .filter(key) - // .filterTemplateValues(value) - // .outputFields(Collections.singletonList("*")) - // .build(); - // SearchResp statusR = client.search(request); - // List> searchResults2 = statusR.getSearchResults(); - // System.out.println("\nSearch results:"); - // for (List results : searchResults2) { - // for (SearchResp.SearchResult result : results) { - // System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString()); - // } - // } - // }); + Map> expressionTemplateValues = new HashMap<>(); + Map params = new HashMap<>(); + params.put("max", 10); + expressionTemplateValues.put("id < {max}", params); + + List list = Arrays.asList(1, 2, 3); + Map params2 = new HashMap<>(); + params2.put("list", list); + expressionTemplateValues.put("id in {list}", params2); + + expressionTemplateValues.forEach((key, value) -> { + SearchReq request = SearchReq.builder() + .collectionName(collectionName) + .data(Collections.singletonList(new FloatVec(new float[]{1.0f, 1.0f, 1.0f, 1.0f}))) + .topK(10) + .filter(key) + .filterTemplateValues(value) + .outputFields(Collections.singletonList("*")) + .build(); + SearchResp statusR = client.search(request); + List> searchResults2 = statusR.getSearchResults(); + System.out.println("\nSearch with template results:"); + for (List results : searchResults2) { + for (SearchResp.SearchResult result : results) { + System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString()); + } + } + }); + client.close(); } } diff --git a/src/main/java/io/milvus/bulkwriter/Buffer.java b/src/main/java/io/milvus/bulkwriter/Buffer.java index aef459827..7c3dd1fc3 100644 --- a/src/main/java/io/milvus/bulkwriter/Buffer.java +++ b/src/main/java/io/milvus/bulkwriter/Buffer.java @@ -20,6 +20,8 @@ package io.milvus.bulkwriter; import com.google.common.collect.Lists; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; import io.milvus.bulkwriter.common.clientenum.BulkFileType; import io.milvus.common.utils.ExceptionUtils; import io.milvus.bulkwriter.common.utils.ParquetUtils; @@ -39,12 +41,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.SortedMap; +import java.util.*; import java.util.stream.Collectors; import static io.milvus.param.Constant.DYNAMIC_FIELD_NAME; @@ -62,8 +61,8 @@ public Buffer(CollectionSchemaParam collectionSchema, BulkFileType fileType) { this.collectionSchema = collectionSchema; this.fileType = fileType; - buffer = new HashMap<>(); - fields = new HashMap<>(); + buffer = new LinkedHashMap<>(); + fields = new LinkedHashMap<>(); for (FieldType fieldType : collectionSchema.getFieldTypes()) { if (fieldType.isPrimaryKey() && fieldType.isAutoID()) @@ -103,7 +102,7 @@ public void appendRow(Map row) { } // verify row count of fields are equal - public List persist(String localPath, Integer bufferSize, Integer bufferRowCount) { + public List persist(String localPath, Map config) throws IOException { int rowCount = -1; for (String key : buffer.keySet()) { if (rowCount < 0) { @@ -116,13 +115,21 @@ public List persist(String localPath, Integer bufferSize, Integer buffer // output files if (fileType == BulkFileType.PARQUET) { + Integer bufferSize = (Integer) config.get("bufferSize"); + Integer bufferRowCount = (Integer) config.get("bufferRowCount"); return persistParquet(localPath, bufferSize, bufferRowCount); + } else if (fileType == BulkFileType.JSON) { + return persistJSON(localPath); + } else if (fileType == BulkFileType.CSV) { + String separator = (String)config.getOrDefault("sep", "\t"); + String nullKey = (String)config.getOrDefault("nullkey", ""); + return persistCSV(localPath, separator, nullKey); } ExceptionUtils.throwUnExpectedException("Unsupported file type: " + fileType); return null; } - private List persistParquet(String localPath, Integer bufferSize, Integer bufferRowCount) { + private List persistParquet(String localPath, Integer bufferSize, Integer bufferRowCount) throws IOException { String filePath = localPath + ".parquet"; // calculate a proper row group size @@ -178,6 +185,7 @@ private List persistParquet(String localPath, Integer bufferSize, Intege } } catch (IOException e) { e.printStackTrace(); + throw e; } String msg = String.format("Successfully persist file %s, total size: %s, row count: %s, row group size: %s", @@ -186,6 +194,90 @@ private List persistParquet(String localPath, Integer bufferSize, Intege return Lists.newArrayList(filePath); } + private List persistJSON(String localPath) throws IOException { + String filePath = localPath + ".json"; + + Gson gson = new GsonBuilder().serializeNulls().create(); + List> data = new ArrayList<>(); + + List fieldNameList = Lists.newArrayList(buffer.keySet()); + int size = buffer.get(fieldNameList.get(0)).size(); + for (int i = 0; i < size; ++i) { + Map row = new HashMap<>(); + for (String fieldName : fieldNameList) { + if (buffer.get(fieldName).get(i) instanceof ByteBuffer) { + row.put(fieldName, ((ByteBuffer)buffer.get(fieldName).get(i)).array()); + } else { + row.put(fieldName, buffer.get(fieldName).get(i)); + } + } + data.add(row); + } + + try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(filePath))) { + bufferedWriter.write("[\n"); + for (int i = 0; i < data.size(); i++) { + String json = gson.toJson(data.get(i)); + if (i != data.size()-1) { + json += ","; + } + bufferedWriter.write(json); + bufferedWriter.newLine(); + } + bufferedWriter.write("]\n"); + } catch (IOException e) { + e.printStackTrace(); + throw e; + } + + return Lists.newArrayList(filePath); + } + + private List persistCSV(String localPath, String separator, String nullKey) throws IOException { + String filePath = localPath + ".csv"; + + Gson gson = new GsonBuilder().serializeNulls().create(); + List fieldNameList = Lists.newArrayList(buffer.keySet()); + try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(filePath))) { + bufferedWriter.write(String.join(separator, fieldNameList)); + bufferedWriter.newLine(); + int size = buffer.get(fieldNameList.get(0)).size(); + for (int i = 0; i < size; ++i) { + List values = new ArrayList<>(); + for (String fieldName : fieldNameList) { + Object val = buffer.get(fieldName).get(i); + String strVal = ""; + if (val == null) { + strVal = nullKey; + } else if (val instanceof ByteBuffer) { + strVal = Arrays.toString(((ByteBuffer) val).array()); + } else if (val instanceof List || val instanceof Map) { + strVal = gson.toJson(val); // server-side is using json to parse array field and vector field + } else { + strVal = val.toString(); + } + + // CSV format, all the single quotation should be replaced by double quotation + if (strVal.startsWith("\"") && strVal.endsWith("\"")) { + strVal = strVal.substring(1, strVal.length() - 1); + } + strVal = strVal.replace("\\\"", "\""); + strVal = strVal.replace("\"", "\"\""); + strVal = "\"" + strVal + "\""; + values.add(strVal); + } + + bufferedWriter.write(String.join(separator, values)); + bufferedWriter.newLine(); + } + } catch (IOException e) { + e.printStackTrace(); + throw e; + } + + return Lists.newArrayList(filePath); + } + private void appendGroup(Group group, String paramName, Object value, FieldType fieldType) { DataType dataType = fieldType.getDataType(); switch (dataType) { diff --git a/src/main/java/io/milvus/bulkwriter/LocalBulkWriter.java b/src/main/java/io/milvus/bulkwriter/LocalBulkWriter.java index 6bcb50d04..2538f49b7 100644 --- a/src/main/java/io/milvus/bulkwriter/LocalBulkWriter.java +++ b/src/main/java/io/milvus/bulkwriter/LocalBulkWriter.java @@ -46,6 +46,7 @@ public class LocalBulkWriter extends BulkWriter implements AutoCloseable { private Map workingThread; private ReentrantLock workingThreadLock; private List> localFiles; + private final Map config; public LocalBulkWriter(LocalBulkWriterParam bulkWriterParam) throws IOException { super(bulkWriterParam.getCollectionSchema(), bulkWriterParam.getChunkSize(), bulkWriterParam.getFileType()); @@ -54,16 +55,22 @@ public LocalBulkWriter(LocalBulkWriterParam bulkWriterParam) throws IOException this.workingThreadLock = new ReentrantLock(); this.workingThread = new HashMap<>(); this.localFiles = Lists.newArrayList(); + this.config = bulkWriterParam.getConfig(); this.makeDir(); } - protected LocalBulkWriter(CollectionSchemaParam collectionSchema, int chunkSize, BulkFileType fileType, String localPath) throws IOException { + protected LocalBulkWriter(CollectionSchemaParam collectionSchema, + int chunkSize, + BulkFileType fileType, + String localPath, + Map config) throws IOException { super(collectionSchema, chunkSize, fileType); this.localPath = localPath; this.uuid = UUID.randomUUID().toString(); this.workingThreadLock = new ReentrantLock(); this.workingThread = new HashMap<>(); this.localFiles = Lists.newArrayList(); + this.config = config; this.makeDir(); } @@ -84,7 +91,7 @@ public void appendRow(JsonObject rowData) throws IOException, InterruptedExcepti public void commit(boolean async) throws InterruptedException { // _async=True, the flush thread is asynchronously - while (workingThread.size() > 0) { + while (!workingThread.isEmpty()) { String msg = String.format("Previous flush action is not finished, %s is waiting...", Thread.currentThread().getName()); logger.info(msg); TimeUnit.SECONDS.sleep(5); @@ -116,13 +123,20 @@ private void flush(Integer bufferSize, Integer bufferRowCount) { java.nio.file.Path path = Paths.get(localPath); java.nio.file.Path flushDirPath = path.resolve(String.valueOf(flushCount)); + Map config = new HashMap<>(this.config); + config.put("bufferSize", bufferSize); + config.put("bufferRowCount", bufferRowCount); Buffer oldBuffer = super.newBuffer(); if (oldBuffer.getRowCount() > 0) { - List fileList = oldBuffer.persist( - flushDirPath.toString(), bufferSize, bufferRowCount - ); - localFiles.add(fileList); - callBack(fileList); + try { + List fileList = oldBuffer.persist(flushDirPath.toString(), config); + localFiles.add(fileList); + callBack(fileList); + } catch (IOException e) { + // this function is running in a thread + // TODO: interrupt main thread if failed to persist file + logger.error(e.getMessage()); + } } workingThread.remove(Thread.currentThread().getName()); String msg = String.format("Flush thread done, name: %s", Thread.currentThread().getName()); diff --git a/src/main/java/io/milvus/bulkwriter/LocalBulkWriterParam.java b/src/main/java/io/milvus/bulkwriter/LocalBulkWriterParam.java index 0f2f72b4e..7c2426641 100644 --- a/src/main/java/io/milvus/bulkwriter/LocalBulkWriterParam.java +++ b/src/main/java/io/milvus/bulkwriter/LocalBulkWriterParam.java @@ -29,6 +29,9 @@ import lombok.NonNull; import lombok.ToString; +import java.util.HashMap; +import java.util.Map; + /** * Parameters for bulkWriter interface. */ @@ -39,12 +42,14 @@ public class LocalBulkWriterParam { private final String localPath; private final int chunkSize; private final BulkFileType fileType; + private final Map config; private LocalBulkWriterParam(@NonNull Builder builder) { this.collectionSchema = builder.collectionSchema; this.localPath = builder.localPath; this.chunkSize = builder.chunkSize; this.fileType = builder.fileType; + this.config = builder.config; } public static Builder newBuilder() { @@ -59,6 +64,7 @@ public static final class Builder { private String localPath; private int chunkSize = 128 * 1024 * 1024; private BulkFileType fileType = BulkFileType.PARQUET; + private Map config = new HashMap<>(); private Builder() { } @@ -106,6 +112,11 @@ public Builder withFileType(BulkFileType fileType) { return this; } + public Builder withConfig(String key, Object val) { + this.config.put(key, val); + return this; + } + /** * Verifies parameters and creates a new {@link LocalBulkWriterParam} instance. * diff --git a/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java b/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java index f9b56fb77..0e2b958bf 100644 --- a/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java +++ b/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java @@ -42,7 +42,9 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class RemoteBulkWriter extends LocalBulkWriter { private static final Logger logger = LoggerFactory.getLogger(RemoteBulkWriter.class); @@ -54,7 +56,11 @@ public class RemoteBulkWriter extends LocalBulkWriter { private List> remoteFiles; public RemoteBulkWriter(RemoteBulkWriterParam bulkWriterParam) throws IOException { - super(bulkWriterParam.getCollectionSchema(), bulkWriterParam.getChunkSize(), bulkWriterParam.getFileType(), generatorLocalPath()); + super(bulkWriterParam.getCollectionSchema(), + bulkWriterParam.getChunkSize(), + bulkWriterParam.getFileType(), + generatorLocalPath(), + bulkWriterParam.getConfig()); Path path = Paths.get(bulkWriterParam.getRemotePath()); Path remoteDirPath = path.resolve(getUUID()); this.remotePath = remoteDirPath.toString(); diff --git a/src/main/java/io/milvus/bulkwriter/RemoteBulkWriterParam.java b/src/main/java/io/milvus/bulkwriter/RemoteBulkWriterParam.java index 58a47e824..b724cc83d 100644 --- a/src/main/java/io/milvus/bulkwriter/RemoteBulkWriterParam.java +++ b/src/main/java/io/milvus/bulkwriter/RemoteBulkWriterParam.java @@ -31,6 +31,9 @@ import lombok.ToString; import org.jetbrains.annotations.NotNull; +import java.util.HashMap; +import java.util.Map; + /** * Parameters for bulkWriter interface. */ @@ -42,6 +45,7 @@ public class RemoteBulkWriterParam { private final String remotePath; private final int chunkSize; private final BulkFileType fileType; + private final Map config; private RemoteBulkWriterParam(@NonNull Builder builder) { this.collectionSchema = builder.collectionSchema; @@ -49,6 +53,7 @@ private RemoteBulkWriterParam(@NonNull Builder builder) { this.remotePath = builder.remotePath; this.chunkSize = builder.chunkSize; this.fileType = builder.fileType; + this.config = builder.config; } public static Builder newBuilder() { @@ -64,6 +69,7 @@ public static final class Builder { private String remotePath; private int chunkSize = 1024 * 1024 * 1024; private BulkFileType fileType = BulkFileType.PARQUET; + private Map config = new HashMap<>(); private Builder() { } @@ -116,6 +122,11 @@ public Builder withFileType(@NonNull BulkFileType fileType) { return this; } + public Builder withConfig(String key, Object val) { + this.config.put(key, val); + return this; + } + /** * Verifies parameters and creates a new {@link RemoteBulkWriterParam} instance. * diff --git a/src/main/java/io/milvus/bulkwriter/common/clientenum/BulkFileType.java b/src/main/java/io/milvus/bulkwriter/common/clientenum/BulkFileType.java index 625e596ab..aba2864b8 100644 --- a/src/main/java/io/milvus/bulkwriter/common/clientenum/BulkFileType.java +++ b/src/main/java/io/milvus/bulkwriter/common/clientenum/BulkFileType.java @@ -21,6 +21,8 @@ public enum BulkFileType { PARQUET(1), + JSON(2), + CSV(3), ; private Integer code; diff --git a/src/test/java/io/milvus/TestUtils.java b/src/test/java/io/milvus/TestUtils.java new file mode 100644 index 000000000..5b0fb43ad --- /dev/null +++ b/src/test/java/io/milvus/TestUtils.java @@ -0,0 +1,162 @@ +package io.milvus; + +import io.milvus.common.utils.Float16Utils; +import io.milvus.grpc.DataType; +import io.milvus.param.collection.FieldType; +import org.junit.jupiter.api.Assertions; + +import java.nio.ByteBuffer; +import java.util.*; + +public class TestUtils { + private int dimension = 256; + private static final Random RANDOM = new Random(); + + public TestUtils(int dimension) { + this.dimension = dimension; + } + + public List generateFloatVector(int dim) { + List vector = new ArrayList<>(); + for (int i = 0; i < dim; ++i) { + vector.add(RANDOM.nextFloat()); + } + return vector; + } + + public List generateFloatVector() { + return generateFloatVector(dimension); + } + + public List> generateFloatVectors(int count) { + List> vectors = new ArrayList<>(); + for (int n = 0; n < count; ++n) { + vectors.add(generateFloatVector()); + } + + return vectors; + } + + public ByteBuffer generateBinaryVector(int dim) { + int byteCount = dim / 8; + ByteBuffer vector = ByteBuffer.allocate(byteCount); + for (int i = 0; i < byteCount; ++i) { + vector.put((byte) RANDOM.nextInt(Byte.MAX_VALUE)); + } + return vector; + } + + public ByteBuffer generateBinaryVector() { + return generateBinaryVector(dimension); + } + + public List generateBinaryVectors(int count) { + List vectors = new ArrayList<>(); + for (int n = 0; n < count; ++n) { + vectors.add(generateBinaryVector()); + } + return vectors; + + } + + public ByteBuffer generateFloat16Vector() { + List vector = generateFloatVector(); + return Float16Utils.f32VectorToFp16Buffer(vector); + } + + public List generateFloat16Vectors(int count) { + List vectors = new ArrayList<>(); + for (int n = 0; n < count; ++n) { + vectors.add(generateFloat16Vector()); + } + return vectors; + } + + public ByteBuffer generateBFloat16Vector() { + List vector = generateFloatVector(); + return Float16Utils.f32VectorToBf16Buffer(vector); + } + + public List generateBFloat16Vectors(int count) { + List vectors = new ArrayList<>(); + for (int n = 0; n < count; ++n) { + vectors.add(generateBFloat16Vector()); + } + return vectors; + } + + public SortedMap generateSparseVector() { + SortedMap sparse = new TreeMap<>(); + int dim = RANDOM.nextInt(10) + 10; + for (int i = 0; i < dim; ++i) { + sparse.put((long) RANDOM.nextInt(1000000), RANDOM.nextFloat()); + } + return sparse; + } + + public List> generateSparseVectors(int count) { + List> vectors = new ArrayList<>(); + for (int n = 0; n < count; ++n) { + vectors.add(generateSparseVector()); + } + return vectors; + } + + public List generateRandomArray(DataType eleType, int maxCapacity) { + switch (eleType) { + case Bool: { + List values = new ArrayList<>(); + for (int i = 0; i < maxCapacity; i++) { + values.add(i%10 == 0); + } + return values; + } + case Int8: + case Int16: { + List values = new ArrayList<>(); + for (int i = 0; i < maxCapacity; i++) { + values.add((short)RANDOM.nextInt(256)); + } + return values; + } + case Int32: { + List values = new ArrayList<>(); + for (int i = 0; i < maxCapacity; i++) { + values.add(RANDOM.nextInt()); + } + return values; + } + case Int64: { + List values = new ArrayList<>(); + for (int i = 0; i < maxCapacity; i++) { + values.add(RANDOM.nextLong()); + } + return values; + } + case Float: { + List values = new ArrayList<>(); + for (int i = 0; i < maxCapacity; i++) { + values.add(RANDOM.nextFloat()); + } + return values; + } + case Double: { + List values = new ArrayList<>(); + for (int i = 0; i < maxCapacity; i++) { + values.add(RANDOM.nextDouble()); + } + return values; + } + case VarChar: { + List values = new ArrayList<>(); + for (int i = 0; i < maxCapacity; i++) { + values.add(String.format("varchar_arr_%d", i)); + } + return values; + } + default: + Assertions.fail(); + } + return null; + } +} diff --git a/src/test/java/io/milvus/bulkwriter/BulkWriterTest.java b/src/test/java/io/milvus/bulkwriter/BulkWriterTest.java index 497be20a6..22d9a2513 100644 --- a/src/test/java/io/milvus/bulkwriter/BulkWriterTest.java +++ b/src/test/java/io/milvus/bulkwriter/BulkWriterTest.java @@ -19,7 +19,12 @@ package io.milvus.bulkwriter; +import com.google.gson.JsonObject; +import io.milvus.TestUtils; +import io.milvus.bulkwriter.common.clientenum.BulkFileType; +import io.milvus.bulkwriter.common.utils.GeneratorUtils; import io.milvus.bulkwriter.common.utils.V2AdapterUtils; +import io.milvus.common.utils.JsonUtils; import io.milvus.param.collection.CollectionSchemaParam; import io.milvus.param.collection.FieldType; import io.milvus.v2.common.DataType; @@ -28,19 +33,25 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Random; public class BulkWriterTest { - @Test - void testV2AdapterUtils() { + private static final int DIMENSION = 128; + private static final TestUtils utils = new TestUtils(DIMENSION); + + CreateCollectionReq.CollectionSchema buildSchema() { CreateCollectionReq.CollectionSchema schemaV2 = CreateCollectionReq.CollectionSchema.builder() + .enableDynamicField(true) .build(); schemaV2.addField(AddFieldReq.builder() .fieldName("id") .dataType(DataType.Int64) - .isPrimaryKey(Boolean.TRUE) + .isPrimaryKey(true) + .autoID(true) .build()); schemaV2.addField(AddFieldReq.builder() .fieldName("bool_field") @@ -73,40 +84,47 @@ void testV2AdapterUtils() { schemaV2.addField(AddFieldReq.builder() .fieldName("varchar_field") .dataType(DataType.VarChar) + .maxLength(100) .build()); schemaV2.addField(AddFieldReq.builder() .fieldName("json_field") .dataType(DataType.JSON) .build()); schemaV2.addField(AddFieldReq.builder() - .fieldName("arr_int_field") + .fieldName("arr_int32_field") .dataType(DataType.Array) - .maxCapacity(50) + .maxCapacity(20) .elementType(DataType.Int32) .build()); schemaV2.addField(AddFieldReq.builder() .fieldName("arr_float_field") .dataType(DataType.Array) - .maxCapacity(20) + .maxCapacity(10) .elementType(DataType.Float) .build()); schemaV2.addField(AddFieldReq.builder() .fieldName("arr_varchar_field") .dataType(DataType.Array) - .maxCapacity(10) + .maxLength(50) + .maxCapacity(5) .elementType(DataType.VarChar) .build()); schemaV2.addField(AddFieldReq.builder() .fieldName("float_vector_field") .dataType(DataType.FloatVector) - .dimension(128) + .dimension(DIMENSION) .build()); schemaV2.addField(AddFieldReq.builder() .fieldName("binary_vector_field") .dataType(DataType.BinaryVector) - .dimension(512) + .dimension(DIMENSION) .build()); + return schemaV2; + } + @Test + void testV2AdapterUtils() { + CreateCollectionReq.CollectionSchema schemaV2 = buildSchema(); CollectionSchemaParam schemaV1 = V2AdapterUtils.convertV2Schema(schemaV2); Assertions.assertEquals(schemaV2.isEnableDynamicField(), schemaV1.isEnableDynamicField()); @@ -143,4 +161,111 @@ void testV2AdapterUtils() { } } } + + private static void buildData(BulkWriter writer, int rowCount, boolean isEnableDynamicField) throws IOException, InterruptedException { + Random random = new Random(); + for (int i = 0; i < rowCount; ++i) { + JsonObject rowObject = new JsonObject(); + + // scalar field + rowObject.addProperty("bool_field", i % 5 == 0); + rowObject.addProperty("int8_field", i % 128); + rowObject.addProperty("int16_field", i % 1000); + rowObject.addProperty("int32_field", i % 100000); + rowObject.addProperty("int64_field", i); + rowObject.addProperty("float_field", i / 3); + rowObject.addProperty("double_field", i / 7); + rowObject.addProperty("varchar_field", "varchar_" + i); + rowObject.addProperty("json_field", String.format("{\"dummy\": %s, \"ok\": \"name_%s\"}", i, i)); + + // vector field + rowObject.add("float_vector_field", JsonUtils.toJsonTree(utils.generateFloatVector())); + rowObject.add("binary_vector_field", JsonUtils.toJsonTree(utils.generateBinaryVector().array())); + + // array field + rowObject.add("arr_int32_field", JsonUtils.toJsonTree(GeneratorUtils.generatorInt32Value(random.nextInt(20)))); + rowObject.add("arr_float_field", JsonUtils.toJsonTree(GeneratorUtils.generatorFloatValue(random.nextInt(10)))); + rowObject.add("arr_varchar_field", JsonUtils.toJsonTree(GeneratorUtils.generatorVarcharValue(random.nextInt(5), 5))); + + // dynamic fields + if (isEnableDynamicField) { + rowObject.addProperty("dynamic", "dynamic_" + i); + } + + writer.appendRow(rowObject); + } + } + + @Test + void testWriteParquet() { + try { + CreateCollectionReq.CollectionSchema schemaV2 = buildSchema(); + LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder() + .withCollectionSchema(schemaV2) + .withLocalPath("/tmp/bulk_writer") + .withFileType(BulkFileType.PARQUET) + .build(); + LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam); + buildData(localBulkWriter, 10, schemaV2.isEnableDynamicField()); + + System.out.printf("%s rows appends%n", localBulkWriter.getTotalRowCount()); + System.out.printf("%s rows in buffer not flushed%n", localBulkWriter.getBufferRowCount()); + localBulkWriter.commit(false); + List> filePaths = localBulkWriter.getBatchFiles(); + System.out.println(filePaths); + Assertions.assertEquals(1, filePaths.size()); + Assertions.assertEquals(1, filePaths.get(0).size()); + } catch (Exception e) { + Assertions.fail(e.getMessage()); + } + } + + @Test + void testWriteJson() { + try { + CreateCollectionReq.CollectionSchema schemaV2 = buildSchema(); + LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder() + .withCollectionSchema(schemaV2) + .withLocalPath("/tmp/bulk_writer") + .withFileType(BulkFileType.JSON) + .build(); + LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam); + buildData(localBulkWriter, 10, schemaV2.isEnableDynamicField()); + + System.out.printf("%s rows appends%n", localBulkWriter.getTotalRowCount()); + System.out.printf("%s rows in buffer not flushed%n", localBulkWriter.getBufferRowCount()); + localBulkWriter.commit(false); + List> filePaths = localBulkWriter.getBatchFiles(); + System.out.println(filePaths); + Assertions.assertEquals(1, filePaths.size()); + Assertions.assertEquals(1, filePaths.get(0).size()); + } catch (Exception e) { + Assertions.fail(e.getMessage()); + } + } + + @Test + void testWriteCSV() { + try { + CreateCollectionReq.CollectionSchema schemaV2 = buildSchema(); + LocalBulkWriterParam bulkWriterParam = LocalBulkWriterParam.newBuilder() + .withCollectionSchema(schemaV2) + .withLocalPath("/tmp/bulk_writer") + .withFileType(BulkFileType.CSV) + .withConfig("sep", ",") + .build(); + LocalBulkWriter localBulkWriter = new LocalBulkWriter(bulkWriterParam); + buildData(localBulkWriter, 10, schemaV2.isEnableDynamicField()); + + System.out.printf("%s rows appends%n", localBulkWriter.getTotalRowCount()); + System.out.printf("%s rows in buffer not flushed%n", localBulkWriter.getBufferRowCount()); + localBulkWriter.commit(false); + List> filePaths = localBulkWriter.getBatchFiles(); + System.out.println(filePaths); + Assertions.assertEquals(1, filePaths.size()); + Assertions.assertEquals(1, filePaths.get(0).size()); + } catch (Exception e) { + Assertions.fail(e.getMessage()); + } + } } diff --git a/src/test/java/io/milvus/client/MilvusClientDockerTest.java b/src/test/java/io/milvus/client/MilvusClientDockerTest.java index f06d30cb6..33ed138bf 100644 --- a/src/test/java/io/milvus/client/MilvusClientDockerTest.java +++ b/src/test/java/io/milvus/client/MilvusClientDockerTest.java @@ -22,6 +22,7 @@ import com.google.gson.*; import com.google.common.collect.Lists; import com.google.common.util.concurrent.ListenableFuture; +import io.milvus.TestUtils; import io.milvus.bulkwriter.LocalBulkWriter; import io.milvus.bulkwriter.LocalBulkWriterParam; import io.milvus.bulkwriter.common.clientenum.BulkFileType; @@ -29,7 +30,6 @@ import io.milvus.common.clientenum.ConsistencyLevelEnum; import io.milvus.common.utils.Float16Utils; import io.milvus.common.utils.JsonUtils; -import io.milvus.exception.ParamException; import io.milvus.grpc.*; import io.milvus.orm.iterator.QueryIterator; import io.milvus.orm.iterator.SearchIterator; @@ -73,17 +73,17 @@ @Testcontainers(disabledWithoutDocker = true) class MilvusClientDockerTest { - protected static MilvusClient client; - protected static RandomStringGenerator generator; - protected static final int DIMENSION = 128; - protected static final int ARRAY_CAPACITY = 100; - protected static final float FLOAT16_PRECISION = 0.001f; - protected static final float BFLOAT16_PRECISION = 0.01f; + private static MilvusClient client; + private static RandomStringGenerator generator; + private static final int DIMENSION = 128; + private static final int ARRAY_CAPACITY = 100; + private static final float FLOAT16_PRECISION = 0.001f; + private static final float BFLOAT16_PRECISION = 0.01f; - private static final Random RANDOM = new Random(); + private static final TestUtils utils = new TestUtils(DIMENSION); @Container - private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:master-20241111-fca946de-amd64"); + private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:master-20241121-b983ef9f-amd64"); @BeforeAll public static void setUp() { @@ -112,83 +112,6 @@ private static ConnectParam.Builder connectParamBuilder(String milvusUri) { return ConnectParam.newBuilder().withUri(milvusUri); } - private List generateFloatVector() { - List vector = new ArrayList<>(); - for (int i = 0; i < DIMENSION; ++i) { - vector.add(RANDOM.nextFloat()); - } - return vector; - } - - private List> generateFloatVectors(int count) { - List> vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - vectors.add(generateFloatVector()); - } - - return vectors; - } - - private ByteBuffer generateBinaryVector() { - int byteCount = DIMENSION / 8; - ByteBuffer vector = ByteBuffer.allocate(byteCount); - for (int i = 0; i < byteCount; ++i) { - vector.put((byte) RANDOM.nextInt(Byte.MAX_VALUE)); - } - return vector; - } - - private List generateBinaryVectors(int count) { - List vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - vectors.add(generateBinaryVector()); - } - return vectors; - } - - private ByteBuffer generateFloat16Vector() { - List vector = generateFloatVector(); - return Float16Utils.f32VectorToFp16Buffer(vector); - } - - private List generateFloat16Vectors(int count) { - List vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - vectors.add(generateFloat16Vector()); - } - return vectors; - } - - private ByteBuffer generateBFloat16Vector() { - List vector = generateFloatVector(); - return Float16Utils.f32VectorToBf16Buffer(vector); - } - - private List generateBFloat16Vectors(int count) { - List vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - vectors.add(generateBFloat16Vector()); - } - return vectors; - } - - private SortedMap generateSparseVector() { - SortedMap sparse = new TreeMap<>(); - int dim = RANDOM.nextInt(10) + 10; - for (int i = 0; i < dim; ++i) { - sparse.put((long) RANDOM.nextInt(1000000), RANDOM.nextFloat()); - } - return sparse; - } - - private List> generateSparseVectors(int count) { - List> vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - vectors.add(generateSparseVector()); - } - return vectors; - } - private CollectionSchemaParam buildSchema(boolean strID, boolean autoID, boolean enabledDynamicSchema, List fieldTypes) { CollectionSchemaParam.Builder builder = CollectionSchemaParam.newBuilder() .withEnableDynamicField(enabledDynamicSchema); @@ -248,71 +171,6 @@ private CollectionSchemaParam buildSchema(boolean strID, boolean autoID, boolean return builder.build(); } - private List generateRandomArray(FieldType field) { - DataType dataType = field.getDataType(); - if (dataType != DataType.Array) { - Assertions.fail(); - } - - DataType eleType = field.getElementType(); - int eleCnt = RANDOM.nextInt(field.getMaxCapacity()); - switch (eleType) { - case Bool: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(i%10 == 0); - } - return values; - } - case Int8: - case Int16: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add((short)RANDOM.nextInt(256)); - } - return values; - } - case Int32: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(RANDOM.nextInt()); - } - return values; - } - case Int64: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(RANDOM.nextLong()); - } - return values; - } - case Float: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(RANDOM.nextFloat()); - } - return values; - } - case Double: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(RANDOM.nextDouble()); - } - return values; - } - case VarChar: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(String.format("varchar_arr_%d", i)); - } - return values; - } - default: - Assertions.fail(); - } - return null; - } - private List generateColumnsData(CollectionSchemaParam schema, int count, int idStart) { List columns = new ArrayList<>(); List fieldTypes = schema.getFieldTypes(); @@ -391,33 +249,33 @@ private List generateColumnsData(CollectionSchemaParam schema case Array: { List> data = new ArrayList<>(); for (int i = idStart; i < idStart + count; ++i) { - data.add(generateRandomArray(fieldType)); + data.add(utils.generateRandomArray(fieldType.getElementType(), fieldType.getMaxCapacity())); } columns.add(new InsertParam.Field(fieldType.getName(), data)); break; } case FloatVector: { - List> data = generateFloatVectors(count); + List> data = utils.generateFloatVectors(count); columns.add(new InsertParam.Field(fieldType.getName(), data)); break; } case BinaryVector: { - List data = generateBinaryVectors(count); + List data = utils.generateBinaryVectors(count); columns.add(new InsertParam.Field(fieldType.getName(), data)); break; } case Float16Vector: { - List data = generateFloat16Vectors(count); + List data = utils.generateFloat16Vectors(count); columns.add(new InsertParam.Field(fieldType.getName(), data)); break; } case BFloat16Vector: { - List data = generateBFloat16Vectors(count); + List data = utils.generateBFloat16Vectors(count); columns.add(new InsertParam.Field(fieldType.getName(), data)); break; } case SparseFloatVector: { - List> data = generateSparseVectors(count); + List> data = utils.generateSparseVectors(count); columns.add(new InsertParam.Field(fieldType.getName(), data)); break; } @@ -476,22 +334,22 @@ private List generateRowsData(CollectionSchemaParam schema, int coun row.add(fieldType.getName(), info); break; case Array: - row.add(fieldType.getName(), JsonUtils.toJsonTree(generateRandomArray(fieldType))); + row.add(fieldType.getName(), JsonUtils.toJsonTree(utils.generateRandomArray(fieldType.getElementType(), fieldType.getMaxCapacity()))); break; case FloatVector: - row.add(fieldType.getName(), JsonUtils.toJsonTree(generateFloatVector())); + row.add(fieldType.getName(), JsonUtils.toJsonTree(utils.generateFloatVector())); break; case BinaryVector: - row.add(fieldType.getName(), JsonUtils.toJsonTree(generateBinaryVector().array())); + row.add(fieldType.getName(), JsonUtils.toJsonTree(utils.generateBinaryVector().array())); break; case Float16Vector: - row.add(fieldType.getName(), JsonUtils.toJsonTree(generateFloat16Vector().array())); + row.add(fieldType.getName(), JsonUtils.toJsonTree(utils.generateFloat16Vector().array())); break; case BFloat16Vector: - row.add(fieldType.getName(), JsonUtils.toJsonTree(generateBFloat16Vector().array())); + row.add(fieldType.getName(), JsonUtils.toJsonTree(utils.generateBFloat16Vector().array())); break; case SparseFloatVector: - row.add(fieldType.getName(), JsonUtils.toJsonTree(generateSparseVector())); + row.add(fieldType.getName(), JsonUtils.toJsonTree(utils.generateSparseVector())); break; default: Assertions.fail(); @@ -1254,7 +1112,7 @@ void testFloat16Vector() { // generate vectors int rowCount = 10000; - List> vectors = generateFloatVectors(rowCount); + List> vectors = utils.generateFloatVectors(rowCount); // insert by column-based List fp16Vectors = new ArrayList<>(); @@ -1482,7 +1340,7 @@ void testMultipleVectorFields() { // search on multiple vector fields AnnSearchParam param1 = AnnSearchParam.newBuilder() .withVectorFieldName(DataType.FloatVector.name()) - .withFloatVectors(generateFloatVectors(1)) + .withFloatVectors(utils.generateFloatVectors(1)) .withMetricType(MetricType.COSINE) .withParams("{\"nprobe\": 32}") .withTopK(10) @@ -1490,7 +1348,7 @@ void testMultipleVectorFields() { AnnSearchParam param2 = AnnSearchParam.newBuilder() .withVectorFieldName(DataType.BinaryVector.name()) - .withBinaryVectors(generateBinaryVectors(1)) + .withBinaryVectors(utils.generateBinaryVectors(1)) .withMetricType(MetricType.HAMMING) .withParams("{}") .withTopK(5) @@ -1498,7 +1356,7 @@ void testMultipleVectorFields() { AnnSearchParam param3 = AnnSearchParam.newBuilder() .withVectorFieldName(DataType.SparseFloatVector.name()) - .withSparseFloatVectors(generateSparseVectors(1)) + .withSparseFloatVectors(utils.generateSparseVectors(1)) .withMetricType(MetricType.IP) .withParams("{\"drop_ratio_search\":0.2}") .withTopK(7) @@ -1571,7 +1429,7 @@ void testAsyncMethods() { List>> futureResponses = new ArrayList<>(); int rowCount = 1000; for (long i = 0L; i < 10; ++i) { - List> vectors = generateFloatVectors(rowCount); + List> vectors = utils.generateFloatVectors(rowCount); List fieldsInsert = new ArrayList<>(); fieldsInsert.add(new InsertParam.Field(DataType.FloatVector.name(), vectors)); @@ -1631,7 +1489,7 @@ void testAsyncMethods() { Assertions.assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue()); // search async - List> targetVectors = generateFloatVectors(2); + List> targetVectors = utils.generateFloatVectors(2); int topK = 5; SearchParam searchParam = SearchParam.newBuilder() .withCollectionName(randomCollectionName) @@ -2239,7 +2097,7 @@ void testArrayField() { intArrArray.add(intArray); floatArrArray.add(floatArray); } - List> vectors = generateFloatVectors(rowCount); + List> vectors = utils.generateFloatVectors(rowCount); List fieldsInsert = new ArrayList<>(); fieldsInsert.add(new InsertParam.Field("id", ids)); @@ -2262,7 +2120,7 @@ void testArrayField() { for (int i = 0; i < rowCount; ++i) { JsonObject row = new JsonObject(); row.addProperty("id", 10000L + (long)i); - List vector = generateFloatVectors(1).get(0); + List vector = utils.generateFloatVectors(1).get(0); row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(vector)); List strArray = new ArrayList<>(); @@ -2290,7 +2148,7 @@ void testArrayField() { System.out.println(rowCount + " rows inserted"); // search - List> searchVectors = generateFloatVectors(1); + List> searchVectors = utils.generateFloatVectors(1); SearchParam searchParam = SearchParam.newBuilder() .withCollectionName(randomCollectionName) .withMetricType(MetricType.L2) @@ -2376,7 +2234,7 @@ void testUpsert() throws InterruptedException { for (long i = 0L; i < rowCount; ++i) { JsonObject row = new JsonObject(); row.addProperty("id", i); - List vector = generateFloatVectors(1).get(0); + List vector = utils.generateFloatVectors(1).get(0); row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(vector)); row.addProperty(DataType.VarChar.name(), String.format("name_%d", i)); row.addProperty("dynamic_value", String.format("dynamic_%d", i)); @@ -2450,7 +2308,7 @@ void testUpsert() throws InterruptedException { for (long i = 0L; i < rowCount; ++i) { JsonObject row = new JsonObject(); row.addProperty("id", rowCount + i); - List vector = generateFloatVectors(1).get(0); + List vector = utils.generateFloatVectors(1).get(0); row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(vector)); row.addProperty(DataType.VarChar.name(), String.format("name_%d", rowCount + i)); rows.add(row); @@ -2493,14 +2351,14 @@ void testUpsert() throws InterruptedException { rows.clear(); JsonObject row = new JsonObject(); row.addProperty("id", 5L); - List vector = generateFloatVectors(1).get(0); + List vector = utils.generateFloatVectors(1).get(0); row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(vector)); row.addProperty(DataType.VarChar.name(), "updated_5"); row.addProperty("dynamic_value", String.format("dynamic_%d", 5)); rows.add(row); row = new JsonObject(); row.addProperty("id", 18L); - vector = generateFloatVectors(1).get(0); + vector = utils.generateFloatVectors(1).get(0); row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(vector)); row.addProperty(DataType.VarChar.name(), "updated_18"); row.addProperty("dynamic_value", 18); @@ -2688,7 +2546,7 @@ void testCollectionHighLevelGet(FieldType primaryField, FieldType vectorField) { } else { row.addProperty(primaryField.getName(), String.valueOf(i)); } - List vector = generateFloatVectors(1).get(0); + List vector = utils.generateFloatVectors(1).get(0); row.add(vectorField.getName(), JsonUtils.toJsonTree(vector)); rows.add(row); primaryIds.add(String.valueOf(i)); @@ -2753,7 +2611,7 @@ void testCollectionHighLevelDelete(FieldType primaryField, FieldType vectorField } else { row.addProperty(primaryField.getName(), String.valueOf(i)); } - List vector = generateFloatVectors(1).get(0); + List vector = utils.generateFloatVectors(1).get(0); row.add(vectorField.getName(), JsonUtils.toJsonTree(vector)); rows.add(row); primaryIds.add(String.valueOf(i)); @@ -2844,10 +2702,10 @@ public void testBulkWriter() { row.add(DataType.Array.name() + "_varchar", JsonUtils.toJsonTree(Lists.newArrayList("aaa", "bbb", "ccc"))); row.add(DataType.Array.name() + "_int32", JsonUtils.toJsonTree(Lists.newArrayList(5, 6, 3, 2, 1))); row.add(DataType.Array.name() + "_float", JsonUtils.toJsonTree(Lists.newArrayList(0.5, 1.8))); - row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(generateFloatVector())); - row.add(DataType.BinaryVector.name(), JsonUtils.toJsonTree(generateBinaryVector().array())); - row.add(DataType.BFloat16Vector.name(), JsonUtils.toJsonTree(generateBFloat16Vector().array())); - row.add(DataType.SparseFloatVector.name(), JsonUtils.toJsonTree(generateSparseVector())); + row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(utils.generateFloatVector())); + row.add(DataType.BinaryVector.name(), JsonUtils.toJsonTree(utils.generateBinaryVector().array())); + row.add(DataType.BFloat16Vector.name(), JsonUtils.toJsonTree(utils.generateBFloat16Vector().array())); + row.add(DataType.SparseFloatVector.name(), JsonUtils.toJsonTree(utils.generateSparseVector())); if (enabledDynamic) { row.addProperty("dynamic_1", i); @@ -2907,7 +2765,7 @@ public void testIterator() { for (long i = 0L; i < rowCount; ++i) { JsonObject row = new JsonObject(); row.addProperty("id", Long.toString(i)); - row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(generateFloatVectors(1).get(0))); + row.add(DataType.FloatVector.name(), JsonUtils.toJsonTree(utils.generateFloatVectors(1).get(0))); JsonObject json = new JsonObject(); if (i%2 == 0) { json.addProperty("even", true); @@ -2978,7 +2836,7 @@ public void testIterator() { Assertions.assertEquals(300, counter); // search iterator - List> vectors = generateFloatVectors(1); + List> vectors = utils.generateFloatVectors(1); SearchIteratorParam.Builder searchIteratorParamBuilder = SearchIteratorParam.newBuilder() .withCollectionName(randomCollectionName) .withOutFields(Lists.newArrayList("*")) @@ -3079,7 +2937,7 @@ void testCacheCollectionSchema() { // insert JsonObject row = new JsonObject(); - row.add("vector", JsonUtils.toJsonTree(generateFloatVectors(1).get(0))); + row.add("vector", JsonUtils.toJsonTree(utils.generateFloatVectors(1).get(0))); R insertR = client.insert(InsertParam.newBuilder() .withCollectionName(randomCollectionName) .withRows(Collections.singletonList(row)) @@ -3214,7 +3072,7 @@ void testNullableAndDefaultValue() { List data = new ArrayList<>(); for (int i = 0; i < 10; i++) { JsonObject row = new JsonObject(); - List vector = generateFloatVector(); + List vector = utils.generateFloatVector(); row.addProperty("id", i); row.add("vector", JsonUtils.toJsonTree(vector)); if (i%2 == 0) { @@ -3233,7 +3091,7 @@ void testNullableAndDefaultValue() { Assertions.assertEquals(R.Status.Success.getCode(), insertR.getStatus().intValue()); // insert by column-based - List> vectors = generateFloatVectors(10); + List> vectors = utils.generateFloatVectors(10); List ids = new ArrayList<>(); List flags = new ArrayList<>(); List descs = new ArrayList<>(); @@ -3291,7 +3149,7 @@ void testNullableAndDefaultValue() { } // search the row-based items - List> searchVectors = generateFloatVectors(1); + List> searchVectors = utils.generateFloatVectors(1); SearchParam searchParam = SearchParam.newBuilder() .withCollectionName(randomCollectionName) .withMetricType(MetricType.L2) diff --git a/src/test/java/io/milvus/client/MilvusMultiClientDockerTest.java b/src/test/java/io/milvus/client/MilvusMultiClientDockerTest.java index c31b83e5f..2f188669b 100644 --- a/src/test/java/io/milvus/client/MilvusMultiClientDockerTest.java +++ b/src/test/java/io/milvus/client/MilvusMultiClientDockerTest.java @@ -20,6 +20,7 @@ package io.milvus.client; import com.google.common.util.concurrent.ListenableFuture; +import io.milvus.TestUtils; import io.milvus.grpc.*; import io.milvus.param.*; import io.milvus.param.collection.*; @@ -49,8 +50,9 @@ class MilvusMultiClientDockerTest { private static MilvusClient client; private static RandomStringGenerator generator; - private static final int dimension = 128; + private static final int DIMENSION = 128; private static final Boolean useDockerCompose = Boolean.TRUE; + private static final TestUtils utils = new TestUtils(DIMENSION); private static void waitMilvusServerReady(String host, int port) { ConnectParam connectParam = connectParamBuilder(host, port) @@ -192,50 +194,6 @@ private static MultiConnectParam.Builder multiConnectParamBuilder() { return MultiConnectParam.newBuilder().withHosts(Arrays.asList(serverAddress, serverSlaveAddress)); } - private List> generateFloatVectors(int count) { - Random ran = new Random(); - List> vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - List vector = new ArrayList<>(); - for (int i = 0; i < dimension; ++i) { - vector.add(ran.nextFloat()); - } - vectors.add(vector); - } - - return vectors; - } - - private List> normalizeFloatVectors(List> src) { - for (List vector : src) { - double total = 0.0; - for (Float val : vector) { - total = total + val * val; - } - float squre = (float) Math.sqrt(total); - for (int i = 0; i < vector.size(); ++i) { - vector.set(i, vector.get(i) / squre); - } - } - - return src; - } - - private List generateBinaryVectors(int count) { - Random ran = new Random(); - List vectors = new ArrayList<>(); - int byteCount = dimension / 8; - for (int n = 0; n < count; ++n) { - ByteBuffer vector = ByteBuffer.allocate(byteCount); - for (int i = 0; i < byteCount; ++i) { - vector.put((byte) ran.nextInt(Byte.MAX_VALUE)); - } - vectors.add(vector); - } - return vectors; - - } - @Test void testFloatVectors() { client.setLogLevel(LogLevel.Error); @@ -260,7 +218,7 @@ void testFloatVectors() { .withDataType(DataType.FloatVector) .withName(field2Name) .withDescription("face") - .withDimension(dimension) + .withDimension(DIMENSION) .build()); fieldsSchema.add(FieldType.newBuilder() @@ -310,7 +268,7 @@ void testFloatVectors() { weights.add(((double) (i + 1) / 100)); ages.add((short) ((i + 1) % 99)); } - List> vectors = generateFloatVectors(rowCount); + List> vectors = utils.generateFloatVectors(rowCount); List fieldsInsert = new ArrayList<>(); fieldsInsert.add(new InsertParam.Field(field1Name, ids)); @@ -529,7 +487,7 @@ void testBinaryVectors() { .withDataType(DataType.BinaryVector) .withName(field2Name) .withDescription("world") - .withDimension(dimension) + .withDimension(DIMENSION) .build(); // create collection @@ -549,7 +507,7 @@ void testBinaryVectors() { for (long i = 0L; i < rowCount; ++i) { ids.add(i); } - List vectors = generateBinaryVectors(rowCount); + List vectors = utils.generateBinaryVectors(rowCount); List fields = new ArrayList<>(); // no need to provide id here since this field is auto_id @@ -659,7 +617,7 @@ void testAsyncMethods() { .withDataType(DataType.FloatVector) .withName(field2Name) .withDescription("face") - .withDimension(dimension) + .withDimension(DIMENSION) .build()); // create collection @@ -676,7 +634,7 @@ void testAsyncMethods() { List>> futureResponses = new ArrayList<>(); int rowCount = 1000; for (long i = 0L; i < 10; ++i) { - List> vectors = normalizeFloatVectors(generateFloatVectors(rowCount)); + List> vectors = utils.generateFloatVectors(rowCount); List fieldsInsert = new ArrayList<>(); fieldsInsert.add(new InsertParam.Field(field2Name, vectors)); @@ -736,7 +694,7 @@ void testAsyncMethods() { assertEquals(R.Status.Success.getCode(), loadR.getStatus().intValue()); // search async - List> targetVectors = normalizeFloatVectors(generateFloatVectors(2)); + List> targetVectors = utils.generateFloatVectors(2); int topK = 5; SearchParam searchParam = SearchParam.newBuilder() .withCollectionName(randomCollectionName) @@ -770,7 +728,7 @@ void testAsyncMethods() { for (int i = 0; i < targetVectors.size(); ++i) { List scores = results.getIDScore(i); assertEquals(topK, scores.size()); - System.out.println(scores.toString()); + System.out.println(scores); } // get query results diff --git a/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java index 55cb39ea6..d6b19a62c 100644 --- a/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java +++ b/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java @@ -23,6 +23,7 @@ import com.google.gson.*; import com.google.gson.reflect.TypeToken; +import io.milvus.TestUtils; import io.milvus.common.clientenum.FunctionType; import io.milvus.common.utils.Float16Utils; import io.milvus.common.utils.JsonUtils; @@ -66,12 +67,12 @@ class MilvusClientV2DockerTest { private static MilvusClientV2 client; private static RandomStringGenerator generator; - private static final int dimension = 256; - + private static final int DIMENSION = 256; private static final Random RANDOM = new Random(); + private static final TestUtils utils = new TestUtils(DIMENSION); @Container - private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:master-20241111-fca946de-amd64"); + private static final MilvusContainer milvus = new MilvusContainer("milvusdb/milvus:master-20241121-b983ef9f-amd64"); @BeforeAll public static void setUp() { @@ -89,73 +90,6 @@ public static void tearDown() throws InterruptedException { } } - private List generateFolatVector(int dim) { - List vector = new ArrayList<>(); - for (int i = 0; i < dim; ++i) { - vector.add(RANDOM.nextFloat()); - } - return vector; - } - - private List generateFolatVector() { - return generateFolatVector(dimension); - } - - private List> generateFloatVectors(int count) { - List> vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - vectors.add(generateFolatVector()); - } - - return vectors; - } - - private ByteBuffer generateBinaryVector() { - int byteCount = dimension / 8; - ByteBuffer vector = ByteBuffer.allocate(byteCount); - for (int i = 0; i < byteCount; ++i) { - vector.put((byte) RANDOM.nextInt(Byte.MAX_VALUE)); - } - return vector; - } - - private List generateBinaryVectors(int count) { - List vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - vectors.add(generateBinaryVector()); - } - return vectors; - - } - - private ByteBuffer generateFloat16Vector() { - List vector = generateFolatVector(); - return Float16Utils.f32VectorToFp16Buffer(vector); - } - - private ByteBuffer generateBFloat16Vector() { - List vector = generateFolatVector(); - return Float16Utils.f32VectorToBf16Buffer(vector); - } - - private SortedMap generateSparseVector() { - SortedMap sparse = new TreeMap<>(); - int dim = RANDOM.nextInt(10) + 10; - for (int i = 0; i < dim; ++i) { - sparse.put((long) RANDOM.nextInt(1000000), RANDOM.nextFloat()); - } - return sparse; - } - - private List> generateSparseVectors(int count) { - List> vectors = new ArrayList<>(); - for (int n = 0; n < count; ++n) { - vectors.add(generateSparseVector()); - } - return vectors; - - } - private CreateCollectionReq.CollectionSchema baseSchema() { CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder() .build(); @@ -221,71 +155,6 @@ private CreateCollectionReq.CollectionSchema baseSchema() { return collectionSchema; } - private JsonArray generateRandomArray(CreateCollectionReq.FieldSchema field) { - DataType dataType = field.getDataType(); - if (dataType != DataType.Array) { - Assertions.fail(); - } - - DataType eleType = field.getElementType(); - int eleCnt = RANDOM.nextInt(field.getMaxCapacity()); - switch (eleType) { - case Bool: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(i%10 == 0); - } - return JsonUtils.toJsonTree(values).getAsJsonArray(); - } - case Int8: - case Int16: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add((short)RANDOM.nextInt(256)); - } - return JsonUtils.toJsonTree(values).getAsJsonArray(); - } - case Int32: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(RANDOM.nextInt()); - } - return JsonUtils.toJsonTree(values).getAsJsonArray(); - } - case Int64: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(RANDOM.nextLong()); - } - return JsonUtils.toJsonTree(values).getAsJsonArray(); - } - case Float: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(RANDOM.nextFloat()); - } - return JsonUtils.toJsonTree(values).getAsJsonArray(); - } - case Double: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(RANDOM.nextDouble()); - } - return JsonUtils.toJsonTree(values).getAsJsonArray(); - } - case VarChar: { - List values = new ArrayList<>(); - for (int i = 0; i < eleCnt; i++) { - values.add(String.format("varchar_arr_%d", i)); - } - return JsonUtils.toJsonTree(values).getAsJsonArray(); - } - default: - Assertions.fail(); - } - return null; - } - private List generateRandomData(CreateCollectionReq.CollectionSchema schema, long count) { List fields = schema.getFieldSchemaList(); List rows = new ArrayList<>(); @@ -326,32 +195,33 @@ private List generateRandomData(CreateCollectionReq.CollectionSchema break; } case Array: { - JsonArray array = generateRandomArray(field); + List values = utils.generateRandomArray(io.milvus.grpc.DataType.valueOf(field.getElementType().name()), field.getMaxCapacity()); + JsonArray array = JsonUtils.toJsonTree(values).getAsJsonArray(); row.add(field.getName(), array); break; } case FloatVector: { - List vector = generateFolatVector(); + List vector = utils.generateFloatVector(); row.add(field.getName(), JsonUtils.toJsonTree(vector)); break; } case BinaryVector: { - ByteBuffer vector = generateBinaryVector(); + ByteBuffer vector = utils.generateBinaryVector(); row.add(field.getName(), JsonUtils.toJsonTree(vector.array())); break; } case Float16Vector: { - ByteBuffer vector = generateFloat16Vector(); + ByteBuffer vector = utils.generateFloat16Vector(); row.add(field.getName(), JsonUtils.toJsonTree(vector.array())); break; } case BFloat16Vector: { - ByteBuffer vector = generateBFloat16Vector(); + ByteBuffer vector = utils.generateBFloat16Vector(); row.add(field.getName(), JsonUtils.toJsonTree(vector.array())); break; } case SparseFloatVector: { - SortedMap vector = generateSparseVector(); + SortedMap vector = utils.generateSparseVector(); row.add(field.getName(), JsonUtils.toJsonTree(vector)); break; } @@ -418,7 +288,7 @@ void testFloatVectors() { collectionSchema.addField(AddFieldReq.builder() .fieldName(vectorFieldName) .dataType(DataType.FloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); Map extraParams = new HashMap<>(); @@ -518,7 +388,7 @@ void testFloatVectors() { .collectionName(randomCollectionName) .partitionNames(Collections.singletonList(partitionName)) .annsField(vectorFieldName) - .data(Collections.singletonList(new FloatVec(generateFolatVector()))) + .data(Collections.singletonList(new FloatVec(utils.generateFloatVector()))) .topK(10) .build()); List> searchResults = searchResp.getSearchResults(); @@ -601,7 +471,7 @@ void testBinaryVectors() throws InterruptedException { collectionSchema.addField(AddFieldReq.builder() .fieldName(vectorFieldName) .dataType(DataType.BinaryVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); Map extraParams = new HashMap<>(); @@ -678,12 +548,12 @@ void testFloat16Vectors() { collectionSchema.addField(AddFieldReq.builder() .fieldName(float16Field) .dataType(DataType.Float16Vector) - .dimension(dimension) + .dimension(DIMENSION) .build()); collectionSchema.addField(AddFieldReq.builder() .fieldName(bfloat16Field) .dataType(DataType.BFloat16Vector) - .dimension(dimension) + .dimension(DIMENSION) .build()); List indexes = new ArrayList<>(); @@ -732,7 +602,7 @@ void testFloat16Vectors() { long targetID = 99; JsonObject row = data.get((int)targetID); List originVector = new ArrayList<>(); - for (int i = 0; i < dimension; ++i) { + for (int i = 0; i < DIMENSION; ++i) { originVector.add((float)1/(i+1)); } System.out.println("Original float32 vector: " + originVector); @@ -814,7 +684,7 @@ void testSparseVectors() { collectionSchema.addField(AddFieldReq.builder() .fieldName(vectorFieldName) .dataType(DataType.SparseFloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); Map extraParams = new HashMap<>(); @@ -881,17 +751,17 @@ void testHybridSearch() { collectionSchema.addField(AddFieldReq.builder() .fieldName("float_vector") .dataType(DataType.FloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); collectionSchema.addField(AddFieldReq.builder() .fieldName("binary_vector") .dataType(DataType.BinaryVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); collectionSchema.addField(AddFieldReq.builder() .fieldName("sparse_vector") .dataType(DataType.SparseFloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); List indexParams = new ArrayList<>(); @@ -946,9 +816,9 @@ void testHybridSearch() { List binaryVectors = new ArrayList<>(); List sparseVectors = new ArrayList<>(); for (int i = 0; i < nq; i++) { - floatVectors.add(new FloatVec(generateFolatVector())); - binaryVectors.add(new BinaryVec(generateBinaryVector())); - sparseVectors.add(new SparseFloatVec(generateSparseVector())); + floatVectors.add(new FloatVec(utils.generateFloatVector())); + binaryVectors.add(new BinaryVec(utils.generateBinaryVector())); + sparseVectors.add(new SparseFloatVec(utils.generateSparseVector())); } List searchRequests = new ArrayList<>(); @@ -1218,7 +1088,7 @@ void testIndex() { collectionSchema.addField(AddFieldReq.builder() .fieldName("vector") .dataType(DataType.FloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); List indexes = new ArrayList<>(); @@ -1344,12 +1214,12 @@ void testCacheCollectionSchema() { client.createCollection(CreateCollectionReq.builder() .collectionName(randomCollectionName) .autoID(true) - .dimension(dimension) + .dimension(DIMENSION) .build()); // insert JsonObject row = new JsonObject(); - row.add("vector", JsonUtils.toJsonTree(generateFloatVectors(1).get(0))); + row.add("vector", JsonUtils.toJsonTree(utils.generateFloatVectors(1).get(0))); InsertResp insertResp = client.insert(InsertReq.builder() .collectionName(randomCollectionName) .data(Collections.singletonList(row)) @@ -1394,22 +1264,22 @@ public void testIterator() { collectionSchema.addField(AddFieldReq.builder() .fieldName("float_vector") .dataType(DataType.FloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); collectionSchema.addField(AddFieldReq.builder() .fieldName("binary_vector") .dataType(DataType.BinaryVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); collectionSchema.addField(AddFieldReq.builder() .fieldName("sparse_vector") .dataType(DataType.SparseFloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); collectionSchema.addField(AddFieldReq.builder() .fieldName("bfloat16_vector") .dataType(DataType.BFloat16Vector) - .dimension(dimension) + .dimension(DIMENSION) .build()); List indexParams = new ArrayList<>(); @@ -1461,7 +1331,7 @@ public void testIterator() { .outputFields(Lists.newArrayList("*")) .batchSize(20L) .vectorFieldName("float_vector") - .vectors(Collections.singletonList(new FloatVec(generateFolatVector()))) + .vectors(Collections.singletonList(new FloatVec(utils.generateFloatVector()))) .expr("int64_field > 500 && int64_field < 1000") .params("{\"range_filter\": 5.0, \"radius\": 50.0}") .topK(1000) @@ -1511,13 +1381,13 @@ public void testIterator() { Assertions.assertTrue(intArr.size() <= 50); // max capacity 50 is defined in the baseSchema() List floatVector = (List)record.get("float_vector"); - Assertions.assertEquals(dimension, floatVector.size()); + Assertions.assertEquals(DIMENSION, floatVector.size()); ByteBuffer binaryVector = (ByteBuffer)record.get("binary_vector"); - Assertions.assertEquals(dimension, binaryVector.limit()*8); + Assertions.assertEquals(DIMENSION, binaryVector.limit()*8); ByteBuffer bfloat16Vector = (ByteBuffer)record.get("bfloat16_vector"); - Assertions.assertEquals(dimension*2, bfloat16Vector.limit()); + Assertions.assertEquals(DIMENSION*2, bfloat16Vector.limit()); SortedMap sparseVector = (SortedMap)record.get("sparse_vector"); Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() <= 20); // defined in generateSparseVector() @@ -1576,13 +1446,13 @@ public void testIterator() { Assertions.assertTrue(intArr.size() <= 50); // max capacity 50 is defined in the baseSchema() List floatVector = (List)record.get("float_vector"); - Assertions.assertEquals(dimension, floatVector.size()); + Assertions.assertEquals(DIMENSION, floatVector.size()); ByteBuffer binaryVector = (ByteBuffer)record.get("binary_vector"); - Assertions.assertEquals(dimension, binaryVector.limit()*8); + Assertions.assertEquals(DIMENSION, binaryVector.limit()*8); ByteBuffer bfloat16Vector = (ByteBuffer)record.get("bfloat16_vector"); - Assertions.assertEquals(dimension*2, bfloat16Vector.limit()); + Assertions.assertEquals(DIMENSION*2, bfloat16Vector.limit()); SortedMap sparseVector = (SortedMap)record.get("sparse_vector"); Assertions.assertTrue(sparseVector.size() >= 10 && sparseVector.size() <= 20); // defined in generateSparseVector() @@ -1650,7 +1520,7 @@ void testDatabase() { collectionSchema.addField(AddFieldReq.builder() .fieldName(vectorFieldName) .dataType(DataType.FloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); IndexParam indexParam = IndexParam.builder() @@ -1781,7 +1651,7 @@ void testMultiThreadsInsert() { for (int j = 0; j < cnt; j++) { JsonObject obj = new JsonObject(); obj.addProperty("id", String.format("%d", i*cnt + j)); - List vector = generateFolatVector(dim); + List vector = utils.generateFloatVector(dim); obj.add("vector", JsonUtils.toJsonTree(vector)); obj.addProperty("dataTime", System.currentTimeMillis()); rows.add(obj); @@ -1825,7 +1695,7 @@ void testMultiThreadsInsert() { for (int j = 0; j < cnt; j++) { JsonObject obj = new JsonObject(); obj.addProperty("id", String.format("%d", i*cnt + j)); - List vector = generateFolatVector(dim); + List vector = utils.generateFloatVector(dim); obj.add("vector", JsonUtils.toJsonTree(vector)); obj.addProperty("dataTime", System.currentTimeMillis()); rows.add(obj); @@ -1908,7 +1778,7 @@ void testNullableAndDefaultValue() { List data = new ArrayList<>(); for (int i = 0; i < 10; i++) { JsonObject row = new JsonObject(); - List vector = generateFolatVector(dim); + List vector = utils.generateFloatVector(dim); row.addProperty("id", i); row.add("vector", JsonUtils.toJsonTree(vector)); if (i%2 == 0) { @@ -1954,7 +1824,7 @@ void testNullableAndDefaultValue() { SearchResp searchResp = client.search(SearchReq.builder() .collectionName(randomCollectionName) .annsField("vector") - .data(Collections.singletonList(new FloatVec(generateFolatVector(dim)))) + .data(Collections.singletonList(new FloatVec(utils.generateFloatVector(dim)))) .topK(10) .outputFields(Lists.newArrayList("*")) .consistencyLevel(ConsistencyLevel.BOUNDED) @@ -1993,7 +1863,7 @@ void testDocInOut() { collectionSchema.addField(AddFieldReq.builder() .fieldName("dense") .dataType(DataType.FloatVector) - .dimension(dimension) + .dimension(DIMENSION) .build()); collectionSchema.addField(AddFieldReq.builder() .fieldName("sparse") @@ -2071,7 +1941,7 @@ void testDocInOut() { for (int i = 0; i < texts.size(); i++) { JsonObject row = new JsonObject(); row.addProperty("id", i); - row.add("dense", JsonUtils.toJsonTree(generateFolatVector(dimension))); + row.add("dense", JsonUtils.toJsonTree(utils.generateFloatVector(DIMENSION))); row.addProperty("text", texts.get(i)); data.add(row); } @@ -2090,7 +1960,7 @@ void testDocInOut() { SearchResp searchResp = client.search(SearchReq.builder() .collectionName(randomCollectionName) .annsField("sparse") - .data(Collections.singletonList(new EmbeddedText("Vector and AI"))) + .data(Collections.singletonList(new EmbeddedText("milvus AI"))) .topK(10) .outputFields(Lists.newArrayList("*")) .metricType(IndexParam.MetricType.BM25)