Skip to content

Commit

Permalink
BulkWriter supports Json/CSV (#1193)
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <[email protected]>
  • Loading branch information
yhmo authored Nov 26, 2024
1 parent b92a8fb commit 235f170
Show file tree
Hide file tree
Showing 12 changed files with 582 additions and 471 deletions.
58 changes: 30 additions & 28 deletions examples/main/java/io/milvus/v2/SimpleExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,35 +103,37 @@ public static void main(String[] args) {
System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString());
}
}

// search with template expression
// Map<String, Map<String, Object>> expressionTemplateValues = new HashMap<>();
// Map<String, Object> params = new HashMap<>();
// params.put("max", 10);
// expressionTemplateValues.put("id < {max}", params);
//
// List<Object> list = Arrays.asList(1, 2, 3);
// Map<String, Object> params2 = new HashMap<>();
// params2.put("list", list);
// expressionTemplateValues.put("id in {list}", params2);
//
// expressionTemplateValues.forEach((key, value) -> {
// SearchReq request = SearchReq.builder()
// .collectionName(collectionName)
// .data(Collections.singletonList(new FloatVec(new float[]{1.0f, 1.0f, 1.0f, 1.0f})))
// .topK(10)
// .filter(key)
// .filterTemplateValues(value)
// .outputFields(Collections.singletonList("*"))
// .build();
// SearchResp statusR = client.search(request);
// List<List<SearchResp.SearchResult>> searchResults2 = statusR.getSearchResults();
// System.out.println("\nSearch results:");
// for (List<SearchResp.SearchResult> results : searchResults2) {
// for (SearchResp.SearchResult result : results) {
// System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString());
// }
// }
// });
Map<String, Map<String, Object>> expressionTemplateValues = new HashMap<>();
Map<String, Object> params = new HashMap<>();
params.put("max", 10);
expressionTemplateValues.put("id < {max}", params);

List<Object> list = Arrays.asList(1, 2, 3);
Map<String, Object> params2 = new HashMap<>();
params2.put("list", list);
expressionTemplateValues.put("id in {list}", params2);

expressionTemplateValues.forEach((key, value) -> {
SearchReq request = SearchReq.builder()
.collectionName(collectionName)
.data(Collections.singletonList(new FloatVec(new float[]{1.0f, 1.0f, 1.0f, 1.0f})))
.topK(10)
.filter(key)
.filterTemplateValues(value)
.outputFields(Collections.singletonList("*"))
.build();
SearchResp statusR = client.search(request);
List<List<SearchResp.SearchResult>> searchResults2 = statusR.getSearchResults();
System.out.println("\nSearch with template results:");
for (List<SearchResp.SearchResult> results : searchResults2) {
for (SearchResp.SearchResult result : results) {
System.out.printf("ID: %d, Score: %f, %s\n", (long)result.getId(), result.getScore(), result.getEntity().toString());
}
}
});

client.close();
}
}
110 changes: 101 additions & 9 deletions src/main/java/io/milvus/bulkwriter/Buffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package io.milvus.bulkwriter;

import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.milvus.bulkwriter.common.clientenum.BulkFileType;
import io.milvus.common.utils.ExceptionUtils;
import io.milvus.bulkwriter.common.utils.ParquetUtils;
Expand All @@ -39,12 +41,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.*;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.*;
import java.util.stream.Collectors;

import static io.milvus.param.Constant.DYNAMIC_FIELD_NAME;
Expand All @@ -62,8 +61,8 @@ public Buffer(CollectionSchemaParam collectionSchema, BulkFileType fileType) {
this.collectionSchema = collectionSchema;
this.fileType = fileType;

buffer = new HashMap<>();
fields = new HashMap<>();
buffer = new LinkedHashMap<>();
fields = new LinkedHashMap<>();

for (FieldType fieldType : collectionSchema.getFieldTypes()) {
if (fieldType.isPrimaryKey() && fieldType.isAutoID())
Expand Down Expand Up @@ -103,7 +102,7 @@ public void appendRow(Map<String, Object> row) {
}

// verify row count of fields are equal
public List<String> persist(String localPath, Integer bufferSize, Integer bufferRowCount) {
public List<String> persist(String localPath, Map<String, Object> config) throws IOException {
int rowCount = -1;
for (String key : buffer.keySet()) {
if (rowCount < 0) {
Expand All @@ -116,13 +115,21 @@ public List<String> persist(String localPath, Integer bufferSize, Integer buffer

// output files
if (fileType == BulkFileType.PARQUET) {
Integer bufferSize = (Integer) config.get("bufferSize");
Integer bufferRowCount = (Integer) config.get("bufferRowCount");
return persistParquet(localPath, bufferSize, bufferRowCount);
} else if (fileType == BulkFileType.JSON) {
return persistJSON(localPath);
} else if (fileType == BulkFileType.CSV) {
String separator = (String)config.getOrDefault("sep", "\t");
String nullKey = (String)config.getOrDefault("nullkey", "");
return persistCSV(localPath, separator, nullKey);
}
ExceptionUtils.throwUnExpectedException("Unsupported file type: " + fileType);
return null;
}

private List<String> persistParquet(String localPath, Integer bufferSize, Integer bufferRowCount) {
private List<String> persistParquet(String localPath, Integer bufferSize, Integer bufferRowCount) throws IOException {
String filePath = localPath + ".parquet";

// calculate a proper row group size
Expand Down Expand Up @@ -178,6 +185,7 @@ private List<String> persistParquet(String localPath, Integer bufferSize, Intege
}
} catch (IOException e) {
e.printStackTrace();
throw e;
}

String msg = String.format("Successfully persist file %s, total size: %s, row count: %s, row group size: %s",
Expand All @@ -186,6 +194,90 @@ private List<String> persistParquet(String localPath, Integer bufferSize, Intege
return Lists.newArrayList(filePath);
}

/**
 * Persists the buffered rows to a JSON file at {@code localPath + ".json"}.
 * The output is a JSON array with one object per row; null values are kept
 * ({@code serializeNulls}) so the server sees every declared field.
 *
 * @param localPath path prefix of the output file (".json" is appended)
 * @return a single-element list containing the path of the written file
 * @throws IOException if writing the file fails
 */
private List<String> persistJSON(String localPath) throws IOException {
    String filePath = localPath + ".json";

    Gson gson = new GsonBuilder().serializeNulls().create();
    List<String> fieldNameList = Lists.newArrayList(buffer.keySet());
    int rowCount = buffer.get(fieldNameList.get(0)).size();

    // Build one map per row. LinkedHashMap keeps the JSON keys in the same
    // order as the buffer (which is a LinkedHashMap in schema declaration
    // order); the previous HashMap produced arbitrary, unstable key order.
    List<Map<String, Object>> rows = new ArrayList<>(rowCount);
    for (int i = 0; i < rowCount; ++i) {
        Map<String, Object> row = new LinkedHashMap<>();
        for (String fieldName : fieldNameList) {
            Object value = buffer.get(fieldName).get(i);
            // Binary vectors arrive as ByteBuffer; serialize the raw bytes as
            // a JSON array of numbers.
            // NOTE(review): array() returns the full backing array and ignores
            // position/limit — assumes buffers are always fully filled; confirm.
            row.put(fieldName, (value instanceof ByteBuffer) ? ((ByteBuffer) value).array() : value);
        }
        rows.add(row);
    }

    // try-with-resources closes the writer and lets IOException propagate to
    // the caller; the previous catch block only duplicated the error to stderr
    // via printStackTrace() before rethrowing.
    // NOTE(review): FileWriter uses the platform default charset — consider
    // writing with an explicit UTF-8 writer.
    try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(filePath))) {
        bufferedWriter.write("[\n");
        for (int i = 0; i < rows.size(); i++) {
            String json = gson.toJson(rows.get(i));
            if (i != rows.size() - 1) {
                json += ","; // comma between elements, none after the last
            }
            bufferedWriter.write(json);
            bufferedWriter.newLine();
        }
        bufferedWriter.write("]\n");
    }

    return Lists.newArrayList(filePath);
}

/**
 * Persists the buffered rows to a CSV file at {@code localPath + ".csv"}.
 * The first line is a header with the field names; every value is wrapped in
 * double quotes so the separator may safely appear inside a value.
 *
 * @param localPath path prefix of the output file (".csv" is appended)
 * @param separator column separator, e.g. "\t" or ","
 * @param nullKey   placeholder text written for null values
 * @return a single-element list containing the path of the written file
 * @throws IOException if writing the file fails
 */
private List<String> persistCSV(String localPath, String separator, String nullKey) throws IOException {
    String filePath = localPath + ".csv";

    Gson gson = new GsonBuilder().serializeNulls().create();
    List<String> fieldNameList = Lists.newArrayList(buffer.keySet());
    // try-with-resources closes the writer and lets IOException propagate to
    // the caller; the previous catch block only duplicated the error to stderr
    // via printStackTrace() before rethrowing.
    try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(filePath))) {
        bufferedWriter.write(String.join(separator, fieldNameList));
        bufferedWriter.newLine();
        int rowCount = buffer.get(fieldNameList.get(0)).size();
        for (int i = 0; i < rowCount; ++i) {
            List<String> values = new ArrayList<>(fieldNameList.size());
            for (String fieldName : fieldNameList) {
                Object val = buffer.get(fieldName).get(i);
                String strVal;
                if (val == null) {
                    strVal = nullKey;
                } else if (val instanceof ByteBuffer) {
                    // binary vector: render the raw bytes as "[1, 2, 3]"
                    strVal = Arrays.toString(((ByteBuffer) val).array());
                } else if (val instanceof List || val instanceof Map) {
                    // the server parses array/vector/JSON fields as JSON text
                    strVal = gson.toJson(val);
                } else {
                    strVal = val.toString();
                }

                // CSV quoting: strip one pair of enclosing quotes (Gson emits
                // JSON strings already quoted), unescape \" back to ", double
                // every remaining quote per RFC 4180, then wrap the value in
                // quotes. The length check prevents substring(1, 0) from
                // throwing when the value is a single '"' character.
                if (strVal.length() >= 2 && strVal.startsWith("\"") && strVal.endsWith("\"")) {
                    strVal = strVal.substring(1, strVal.length() - 1);
                }
                strVal = strVal.replace("\\\"", "\"");
                strVal = strVal.replace("\"", "\"\"");
                strVal = "\"" + strVal + "\"";
                values.add(strVal);
            }

            bufferedWriter.write(String.join(separator, values));
            bufferedWriter.newLine();
        }
    }

    return Lists.newArrayList(filePath);
}

private void appendGroup(Group group, String paramName, Object value, FieldType fieldType) {
DataType dataType = fieldType.getDataType();
switch (dataType) {
Expand Down
28 changes: 21 additions & 7 deletions src/main/java/io/milvus/bulkwriter/LocalBulkWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class LocalBulkWriter extends BulkWriter implements AutoCloseable {
private Map<String, Thread> workingThread;
private ReentrantLock workingThreadLock;
private List<List<String>> localFiles;
private final Map<String, Object> config;

public LocalBulkWriter(LocalBulkWriterParam bulkWriterParam) throws IOException {
super(bulkWriterParam.getCollectionSchema(), bulkWriterParam.getChunkSize(), bulkWriterParam.getFileType());
Expand All @@ -54,16 +55,22 @@ public LocalBulkWriter(LocalBulkWriterParam bulkWriterParam) throws IOException
this.workingThreadLock = new ReentrantLock();
this.workingThread = new HashMap<>();
this.localFiles = Lists.newArrayList();
this.config = bulkWriterParam.getConfig();
this.makeDir();
}

protected LocalBulkWriter(CollectionSchemaParam collectionSchema, int chunkSize, BulkFileType fileType, String localPath) throws IOException {
/**
 * Internal constructor for subclasses (e.g. RemoteBulkWriter) that supply
 * their own local staging directory instead of taking it from a param object.
 * Creates the output directory as a side effect (makeDir()).
 *
 * @param collectionSchema schema of the target collection
 * @param chunkSize        maximum bytes buffered before a chunk is flushed
 * @param fileType         output format (PARQUET/JSON/CSV)
 * @param localPath        root directory where generated files are written
 * @param config           extra writer options (e.g. "sep"/"nullkey" for CSV),
 *                         forwarded to Buffer.persist() at flush time
 * @throws IOException if the output directory cannot be created
 */
protected LocalBulkWriter(CollectionSchemaParam collectionSchema,
int chunkSize,
BulkFileType fileType,
String localPath,
Map<String, Object> config) throws IOException {
super(collectionSchema, chunkSize, fileType);
this.localPath = localPath;
// unique per-writer subdirectory name, avoids collisions between writers
this.uuid = UUID.randomUUID().toString();
this.workingThreadLock = new ReentrantLock();
this.workingThread = new HashMap<>();
this.localFiles = Lists.newArrayList();
this.config = config;
this.makeDir();
}

Expand All @@ -84,7 +91,7 @@ public void appendRow(JsonObject rowData) throws IOException, InterruptedExcepti

public void commit(boolean async) throws InterruptedException {
// _async=True, the flush thread is asynchronously
while (workingThread.size() > 0) {
while (!workingThread.isEmpty()) {
String msg = String.format("Previous flush action is not finished, %s is waiting...", Thread.currentThread().getName());
logger.info(msg);
TimeUnit.SECONDS.sleep(5);
Expand Down Expand Up @@ -116,13 +123,20 @@ private void flush(Integer bufferSize, Integer bufferRowCount) {
java.nio.file.Path path = Paths.get(localPath);
java.nio.file.Path flushDirPath = path.resolve(String.valueOf(flushCount));

Map<String, Object> config = new HashMap<>(this.config);
config.put("bufferSize", bufferSize);
config.put("bufferRowCount", bufferRowCount);
Buffer oldBuffer = super.newBuffer();
if (oldBuffer.getRowCount() > 0) {
List<String> fileList = oldBuffer.persist(
flushDirPath.toString(), bufferSize, bufferRowCount
);
localFiles.add(fileList);
callBack(fileList);
try {
List<String> fileList = oldBuffer.persist(flushDirPath.toString(), config);
localFiles.add(fileList);
callBack(fileList);
} catch (IOException e) {
// this function is running in a thread
// TODO: interrupt main thread if failed to persist file
logger.error(e.getMessage());
}
}
workingThread.remove(Thread.currentThread().getName());
String msg = String.format("Flush thread done, name: %s", Thread.currentThread().getName());
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/io/milvus/bulkwriter/LocalBulkWriterParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import lombok.NonNull;
import lombok.ToString;

import java.util.HashMap;
import java.util.Map;

/**
* Parameters for <code>bulkWriter</code> interface.
*/
Expand All @@ -39,12 +42,14 @@ public class LocalBulkWriterParam {
private final String localPath;
private final int chunkSize;
private final BulkFileType fileType;
private final Map<String, Object> config;

// Private: instances are created only through Builder (see newBuilder()).
// Copies every builder field into the immutable param object.
private LocalBulkWriterParam(@NonNull Builder builder) {
this.collectionSchema = builder.collectionSchema;
this.localPath = builder.localPath;
this.chunkSize = builder.chunkSize;
this.fileType = builder.fileType;
this.config = builder.config;
}

public static Builder newBuilder() {
Expand All @@ -59,6 +64,7 @@ public static final class Builder {
private String localPath;
private int chunkSize = 128 * 1024 * 1024;
private BulkFileType fileType = BulkFileType.PARQUET;
private Map<String, Object> config = new HashMap<>();

private Builder() {
}
Expand Down Expand Up @@ -106,6 +112,11 @@ public Builder withFileType(BulkFileType fileType) {
return this;
}

/**
 * Sets an extra option for the underlying file writer, for example the
 * CSV separator ("sep") or the null placeholder ("nullkey").
 * Later calls with the same key overwrite the earlier value.
 *
 * @param key option name
 * @param val option value
 * @return <code>Builder</code>
 */
public Builder withConfig(String key, Object val) {
this.config.put(key, val);
return this;
}

/**
* Verifies parameters and creates a new {@link LocalBulkWriterParam} instance.
*
Expand Down
8 changes: 7 additions & 1 deletion src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RemoteBulkWriter extends LocalBulkWriter {
private static final Logger logger = LoggerFactory.getLogger(RemoteBulkWriter.class);
Expand All @@ -54,7 +56,11 @@ public class RemoteBulkWriter extends LocalBulkWriter {
private List<List<String>> remoteFiles;

public RemoteBulkWriter(RemoteBulkWriterParam bulkWriterParam) throws IOException {
super(bulkWriterParam.getCollectionSchema(), bulkWriterParam.getChunkSize(), bulkWriterParam.getFileType(), generatorLocalPath());
super(bulkWriterParam.getCollectionSchema(),
bulkWriterParam.getChunkSize(),
bulkWriterParam.getFileType(),
generatorLocalPath(),
bulkWriterParam.getConfig());
Path path = Paths.get(bulkWriterParam.getRemotePath());
Path remoteDirPath = path.resolve(getUUID());
this.remotePath = remoteDirPath.toString();
Expand Down
Loading

0 comments on commit 235f170

Please sign in to comment.