enhance: add csv support for bulkinsert (#34938)
See this issue for details: #34937

---------

Signed-off-by: OxalisCu <[email protected]>
OxalisCu authored Aug 21, 2024
1 parent ba6db11 commit ed4eaff
Showing 12 changed files with 1,113 additions and 1 deletion.
132 changes: 132 additions & 0 deletions internal/util/importutilv2/csv/reader.go
@@ -0,0 +1,132 @@
package csv

import (
    "context"
    "encoding/csv"
    "fmt"
    "io"

    "go.uber.org/atomic"
    "go.uber.org/zap"

    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/storage"
    "github.com/milvus-io/milvus/pkg/log"
    "github.com/milvus-io/milvus/pkg/util/merr"
)

type Row = map[storage.FieldID]any

type reader struct {
    ctx    context.Context
    cm     storage.ChunkManager
    schema *schemapb.CollectionSchema

    cr     *csv.Reader
    parser RowParser

    fileSize   *atomic.Int64
    bufferSize int
    count      int64
    filePath   string
}

func NewReader(ctx context.Context, cm storage.ChunkManager, schema *schemapb.CollectionSchema, path string, bufferSize int, sep rune) (*reader, error) {
    cmReader, err := cm.Reader(ctx, path)
    if err != nil {
        return nil, merr.WrapErrImportFailed(fmt.Sprintf("read csv file failed, path=%s, err=%s", path, err.Error()))
    }
    // count, err := estimateReadCountPerBatch(bufferSize, schema)
    // if err != nil {
    //     return nil, err
    // }

    // interval (in rows) at which Read checks whether the buffer size is exceeded
    var count int64 = 1000

    csvReader := csv.NewReader(cmReader)
    csvReader.Comma = sep

    header, err := csvReader.Read()
    if err != nil {
        return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to read csv header, error: %v", err))
    }
    log.Info("csv header parsed", zap.Strings("header", header))

    rowParser, err := NewRowParser(schema, header)
    if err != nil {
        return nil, err
    }
    return &reader{
        ctx:        ctx,
        cm:         cm,
        schema:     schema,
        cr:         csvReader,
        parser:     rowParser,
        fileSize:   atomic.NewInt64(0),
        filePath:   path,
        bufferSize: bufferSize,
        count:      count,
    }, nil
}

func (r *reader) Read() (*storage.InsertData, error) {
    insertData, err := storage.NewInsertData(r.schema)
    if err != nil {
        return nil, err
    }
    var cnt int64 = 0
    for {
        value, err := r.cr.Read()
        if err == io.EOF {
            break
        }
        if err != nil {
            return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to read csv row, error: %v", err))
        }
        if len(value) == 0 {
            break
        }
        row, err := r.parser.Parse(value)
        if err != nil {
            return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to parse row, error: %v", err))
        }
        err = insertData.Append(row)
        if err != nil {
            return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to append row, error: %v", err))
        }
        cnt++
        // every `count` rows, check whether the batch has filled the buffer
        if cnt >= r.count {
            cnt = 0
            if insertData.GetMemorySize() >= r.bufferSize {
                break
            }
        }
    }

    // finished reading: signal EOF once no rows are left
    if insertData.GetRowNum() == 0 {
        return nil, io.EOF
    }

    return insertData, nil
}

// Close is a no-op: the reader does not retain the underlying chunk-manager handle.
func (r *reader) Close() {}

// Size returns the CSV file size, caching it after the first lookup.
func (r *reader) Size() (int64, error) {
    if size := r.fileSize.Load(); size != 0 {
        return size, nil
    }
    size, err := r.cm.Size(r.ctx, r.filePath)
    if err != nil {
        return 0, err
    }
    r.fileSize.Store(size)
    return size, nil
}

// func estimateReadCountPerBatch(bufferSize int, schema *schemapb.CollectionSchema) (int64, error) {
//     sizePerRecord, err := typeutil.EstimateMaxSizePerRecord(schema)
//     if err != nil {
//         return 0, err
//     }
//     if 1000*sizePerRecord <= bufferSize {
//         return 1000, nil
//     }
//     return int64(bufferSize) / int64(sizePerRecord), nil
// }
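
A caller is expected to drain the reader in a loop, one buffered batch per Read() call, until it reports io.EOF. A minimal driver sketch (hypothetical, not part of this commit; it assumes a ChunkManager cm, a CollectionSchema schema, and a consume function are in scope, and that this package is imported as csv):

    // Hypothetical driver sketch: drain the CSV reader batch by batch.
    r, err := csv.NewReader(ctx, cm, schema, "/data/import.csv", 64*1024*1024, ',')
    if err != nil {
        return err
    }
    defer r.Close()
    for {
        batch, err := r.Read() // one buffered batch of rows
        if err == io.EOF {
            break // file exhausted
        }
        if err != nil {
            return err
        }
        consume(batch) // placeholder for the import pipeline's consumer
    }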
173 changes: 173 additions & 0 deletions internal/util/importutilv2/csv/reader_test.go
@@ -0,0 +1,173 @@
package csv

import (
    "context"
    "encoding/csv"
    "fmt"
    "math/rand"
    "os"
    "testing"

    "github.com/stretchr/testify/suite"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/storage"
    "github.com/milvus-io/milvus/internal/util/testutil"
    "github.com/milvus-io/milvus/pkg/common"
    "github.com/milvus-io/milvus/pkg/util/paramtable"
)

type ReaderSuite struct {
    suite.Suite

    numRows     int
    pkDataType  schemapb.DataType
    vecDataType schemapb.DataType
}

func (suite *ReaderSuite) SetupSuite() {
    paramtable.Get().Init(paramtable.NewBaseTable())
}

func (suite *ReaderSuite) SetupTest() {
    suite.numRows = 10
    suite.pkDataType = schemapb.DataType_Int64
    suite.vecDataType = schemapb.DataType_FloatVector
}

func (suite *ReaderSuite) run(dataType schemapb.DataType, elemType schemapb.DataType) {
    schema := &schemapb.CollectionSchema{
        Fields: []*schemapb.FieldSchema{
            {
                FieldID:      100,
                Name:         "pk",
                IsPrimaryKey: true,
                DataType:     suite.pkDataType,
                TypeParams: []*commonpb.KeyValuePair{
                    {
                        Key:   common.MaxLengthKey,
                        Value: "128",
                    },
                },
            },
            {
                FieldID:  101,
                Name:     "vec",
                DataType: suite.vecDataType,
                TypeParams: []*commonpb.KeyValuePair{
                    {
                        Key:   common.DimKey,
                        Value: "8",
                    },
                },
            },
            {
                FieldID:     102,
                Name:        dataType.String(),
                DataType:    dataType,
                ElementType: elemType,
                TypeParams: []*commonpb.KeyValuePair{
                    {
                        Key:   common.MaxLengthKey,
                        Value: "128",
                    },
                },
            },
        },
    }

    // generate csv data
    insertData, err := testutil.CreateInsertData(schema, suite.numRows)
    suite.NoError(err)
    csvData, err := testutil.CreateInsertDataForCSV(schema, insertData)
    suite.NoError(err)

    // write to csv file
    sep := '\t'
    filePath := fmt.Sprintf("/tmp/test_%d_reader.csv", rand.Int())
    defer os.Remove(filePath)
    wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666)
    suite.NoError(err)
    writer := csv.NewWriter(wf)
    writer.Comma = sep
    err = writer.WriteAll(csvData)
    suite.NoError(err)
    suite.NoError(wf.Close())

    // read from csv file
    ctx := context.Background()
    f := storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus_test/test_csv_reader/"))
    cm, err := f.NewPersistentStorageChunkManager(ctx)
    suite.NoError(err)

    // a reader built with the wrong separator must fail: the header cannot be matched to the schema fields
    wrongSep := ','
    _, err = NewReader(ctx, cm, schema, filePath, 64*1024*1024, wrongSep)
    suite.Error(err)
    suite.Contains(err.Error(), "value of field is missed: ")

    // check data
    reader, err := NewReader(ctx, cm, schema, filePath, 64*1024*1024, sep)
    suite.NoError(err)

    checkFn := func(actualInsertData *storage.InsertData, offsetBegin, expectRows int) {
        expectInsertData := insertData
        for fieldID, data := range actualInsertData.Data {
            suite.Equal(expectRows, data.RowNum())
            for i := 0; i < expectRows; i++ {
                expect := expectInsertData.Data[fieldID].GetRow(i + offsetBegin)
                actual := data.GetRow(i)
                suite.Equal(expect, actual)
            }
        }
    }

    res, err := reader.Read()
    suite.NoError(err)
    checkFn(res, 0, suite.numRows)
}

func (suite *ReaderSuite) TestReadScalarFields() {
    suite.run(schemapb.DataType_Bool, schemapb.DataType_None)
    suite.run(schemapb.DataType_Int8, schemapb.DataType_None)
    suite.run(schemapb.DataType_Int16, schemapb.DataType_None)
    suite.run(schemapb.DataType_Int32, schemapb.DataType_None)
    suite.run(schemapb.DataType_Int64, schemapb.DataType_None)
    suite.run(schemapb.DataType_Float, schemapb.DataType_None)
    suite.run(schemapb.DataType_Double, schemapb.DataType_None)
    suite.run(schemapb.DataType_String, schemapb.DataType_None)
    suite.run(schemapb.DataType_VarChar, schemapb.DataType_None)
    suite.run(schemapb.DataType_JSON, schemapb.DataType_None)

    suite.run(schemapb.DataType_Array, schemapb.DataType_Bool)
    suite.run(schemapb.DataType_Array, schemapb.DataType_Int8)
    suite.run(schemapb.DataType_Array, schemapb.DataType_Int16)
    suite.run(schemapb.DataType_Array, schemapb.DataType_Int32)
    suite.run(schemapb.DataType_Array, schemapb.DataType_Int64)
    suite.run(schemapb.DataType_Array, schemapb.DataType_Float)
    suite.run(schemapb.DataType_Array, schemapb.DataType_Double)
    suite.run(schemapb.DataType_Array, schemapb.DataType_String)
}

func (suite *ReaderSuite) TestStringPK() {
    suite.pkDataType = schemapb.DataType_VarChar
    suite.run(schemapb.DataType_Int32, schemapb.DataType_None)
}

func (suite *ReaderSuite) TestVector() {
    suite.vecDataType = schemapb.DataType_BinaryVector
    suite.run(schemapb.DataType_Int32, schemapb.DataType_None)
    suite.vecDataType = schemapb.DataType_FloatVector
    suite.run(schemapb.DataType_Int32, schemapb.DataType_None)
    suite.vecDataType = schemapb.DataType_Float16Vector
    suite.run(schemapb.DataType_Int32, schemapb.DataType_None)
    suite.vecDataType = schemapb.DataType_BFloat16Vector
    suite.run(schemapb.DataType_Int32, schemapb.DataType_None)
    suite.vecDataType = schemapb.DataType_SparseFloatVector
    suite.run(schemapb.DataType_Int32, schemapb.DataType_None)
}

func TestUtil(t *testing.T) {
    suite.Run(t, new(ReaderSuite))
}
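
For reference, the reader treats the first CSV row as a header whose column names must match the schema's field names; each subsequent row carries one entity. A sketch of a conforming tab-separated file for a schema like the one above (the value encodings are handled by the RowParser, which is not shown in this diff, so the vector formatting here is an assumption, not confirmed by the commit):

    pk	vec	Int32
    1	[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]	42
    2	[0.5, 0.4, 0.3, 0.2, 0.1, 0.9, 0.8, 0.7]	7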