diff --git a/internal/util/importutilv2/csv/reader.go b/internal/util/importutilv2/csv/reader.go
new file mode 100644
index 0000000000000..a0f6c6c1a3794
--- /dev/null
+++ b/internal/util/importutilv2/csv/reader.go
@@ -0,0 +1,132 @@
+package csv
+
+import (
+    "context"
+    "encoding/csv"
+    "fmt"
+    "io"
+
+    "go.uber.org/atomic"
+    "go.uber.org/zap"
+
+    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+    "github.com/milvus-io/milvus/internal/storage"
+    "github.com/milvus-io/milvus/pkg/log"
+    "github.com/milvus-io/milvus/pkg/util/merr"
+)
+
+type Row = map[storage.FieldID]any
+
+type reader struct {
+    ctx    context.Context
+    cm     storage.ChunkManager
+    schema *schemapb.CollectionSchema
+
+    cr     *csv.Reader
+    parser RowParser
+
+    fileSize   *atomic.Int64
+    bufferSize int
+    count      int64
+    filePath   string
+}
+
+func NewReader(ctx context.Context, cm storage.ChunkManager, schema *schemapb.CollectionSchema, path string, bufferSize int, sep rune) (*reader, error) {
+    cmReader, err := cm.Reader(ctx, path)
+    if err != nil {
+        return nil, merr.WrapErrImportFailed(fmt.Sprintf("read csv file failed, path=%s, err=%s", path, err.Error()))
+    }
+    // count, err := estimateReadCountPerBatch(bufferSize, schema)
+    // if err != nil {
+    //     return nil, err
+    // }
+
+    // set the interval for determining if the buffer is exceeded
+    var count int64 = 1000
+
+    csvReader := csv.NewReader(cmReader)
+    csvReader.Comma = sep
+
+    header, err := csvReader.Read()
+    if err != nil {
+        return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to read csv header, error: %v", err))
+    }
+    log.Info("csv header parsed", zap.Strings("header", header))
+
+    rowParser, err := NewRowParser(schema, header)
+    if err != nil {
+        return nil, err
+    }
+    return &reader{
+        ctx:        ctx,
+        cm:         cm,
+        schema:     schema,
+        cr:         csvReader,
+        parser:     rowParser,
+        fileSize:   atomic.NewInt64(0),
+        filePath:   path,
+        bufferSize: bufferSize,
+        count:      count,
+    }, nil
+}
+
+func (r *reader) Read() (*storage.InsertData, error) {
+    insertData, err := storage.NewInsertData(r.schema)
+    if err != nil {
+        return nil, err
+    }
+    var cnt int64 = 0
+    for {
+        value, err := r.cr.Read()
+        if err == io.EOF || len(value) == 0 {
+            break
+        }
+        row, err := r.parser.Parse(value)
+        if err != nil {
+            return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to parse row, error: %v", err))
+        }
+        err = insertData.Append(row)
+        if err != nil {
+            return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to append row, error: %v", err))
+        }
+        cnt++
+        if cnt >= r.count {
+            cnt = 0
+            if insertData.GetMemorySize() >= r.bufferSize {
+                break
+            }
+        }
+    }
+
+    // finish reading
+    if insertData.GetRowNum() == 0 {
+        return nil, io.EOF
+    }
+
+    return insertData, nil
+}
+
+func (r *reader) Close() {}
+
+func (r *reader) Size() (int64, error) {
+    if size := r.fileSize.Load(); size != 0 {
+        return size, nil
+    }
+    size, err := r.cm.Size(r.ctx, r.filePath)
+    if err != nil {
+        return 0, err
+    }
+    r.fileSize.Store(size)
+    return size, nil
+}
+
+// func estimateReadCountPerBatch(bufferSize int, schema *schemapb.CollectionSchema) (int64, error) {
+//     sizePerRecord, err := typeutil.EstimateMaxSizePerRecord(schema)
+//     if err != nil {
+//         return 0, err
+//     }
+//     if 1000*sizePerRecord <= bufferSize {
+//         return 1000, nil
+//     }
+//     return int64(bufferSize) / int64(sizePerRecord), nil
+// }
diff --git a/internal/util/importutilv2/csv/reader_test.go b/internal/util/importutilv2/csv/reader_test.go
new file mode 100644
index 0000000000000..9f448ac4af0b1
--- /dev/null
+++ 
b/internal/util/importutilv2/csv/reader_test.go @@ -0,0 +1,173 @@ +package csv + +import ( + "context" + "encoding/csv" + "fmt" + "math/rand" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/testutil" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +type ReaderSuite struct { + suite.Suite + + numRows int + pkDataType schemapb.DataType + vecDataType schemapb.DataType +} + +func (suite *ReaderSuite) SetupSuite() { + paramtable.Get().Init(paramtable.NewBaseTable()) +} + +func (suite *ReaderSuite) SetupTest() { + suite.numRows = 10 + suite.pkDataType = schemapb.DataType_Int64 + suite.vecDataType = schemapb.DataType_FloatVector +} + +func (suite *ReaderSuite) run(dataType schemapb.DataType, elemType schemapb.DataType) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: suite.pkDataType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: "128", + }, + }, + }, + { + FieldID: 101, + Name: "vec", + DataType: suite.vecDataType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + { + FieldID: 102, + Name: dataType.String(), + DataType: dataType, + ElementType: elemType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: "128", + }, + }, + }, + }, + } + + // generate csv data + insertData, err := testutil.CreateInsertData(schema, suite.numRows) + suite.NoError(err) + csvData, err := testutil.CreateInsertDataForCSV(schema, insertData) + suite.NoError(err) + + // write to csv file + sep := '\t' + filePath := fmt.Sprintf("/tmp/test_%d_reader.csv", rand.Int()) + defer os.Remove(filePath) + wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) + assert.NoError(suite.T(), err) + writer := csv.NewWriter(wf) + writer.Comma = sep + writer.WriteAll(csvData) + suite.NoError(err) + + // read from csv file + ctx := context.Background() + f := storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus_test/test_csv_reader/")) + cm, err := f.NewPersistentStorageChunkManager(ctx) + suite.NoError(err) + + // check reader separate fields by '\t' + wrongSep := ',' + _, err = NewReader(ctx, cm, schema, filePath, 64*1024*1024, wrongSep) + suite.Error(err) + suite.Contains(err.Error(), "value of field is missed: ") + + // check data + reader, err := NewReader(ctx, cm, schema, filePath, 64*1024*1024, sep) + suite.NoError(err) + + checkFn := func(actualInsertData *storage.InsertData, offsetBegin, expectRows int) { + expectInsertData := insertData + for fieldID, data := range actualInsertData.Data { + suite.Equal(expectRows, data.RowNum()) + for i := 0; i < expectRows; i++ { + expect := expectInsertData.Data[fieldID].GetRow(i + offsetBegin) + actual := data.GetRow(i) + suite.Equal(expect, actual) + } + } + } + + res, err := reader.Read() + suite.NoError(err) + checkFn(res, 0, suite.numRows) +} + +func (suite *ReaderSuite) TestReadScalarFields() { + suite.run(schemapb.DataType_Bool, schemapb.DataType_None) + suite.run(schemapb.DataType_Int8, schemapb.DataType_None) + suite.run(schemapb.DataType_Int16, schemapb.DataType_None) + 
suite.run(schemapb.DataType_Int32, schemapb.DataType_None) + suite.run(schemapb.DataType_Int64, schemapb.DataType_None) + suite.run(schemapb.DataType_Float, schemapb.DataType_None) + suite.run(schemapb.DataType_Double, schemapb.DataType_None) + suite.run(schemapb.DataType_String, schemapb.DataType_None) + suite.run(schemapb.DataType_VarChar, schemapb.DataType_None) + suite.run(schemapb.DataType_JSON, schemapb.DataType_None) + + suite.run(schemapb.DataType_Array, schemapb.DataType_Bool) + suite.run(schemapb.DataType_Array, schemapb.DataType_Int8) + suite.run(schemapb.DataType_Array, schemapb.DataType_Int16) + suite.run(schemapb.DataType_Array, schemapb.DataType_Int32) + suite.run(schemapb.DataType_Array, schemapb.DataType_Int64) + suite.run(schemapb.DataType_Array, schemapb.DataType_Float) + suite.run(schemapb.DataType_Array, schemapb.DataType_Double) + suite.run(schemapb.DataType_Array, schemapb.DataType_String) +} + +func (suite *ReaderSuite) TestStringPK() { + suite.pkDataType = schemapb.DataType_VarChar + suite.run(schemapb.DataType_Int32, schemapb.DataType_None) +} + +func (suite *ReaderSuite) TestVector() { + suite.vecDataType = schemapb.DataType_BinaryVector + suite.run(schemapb.DataType_Int32, schemapb.DataType_None) + suite.vecDataType = schemapb.DataType_FloatVector + suite.run(schemapb.DataType_Int32, schemapb.DataType_None) + suite.vecDataType = schemapb.DataType_Float16Vector + suite.run(schemapb.DataType_Int32, schemapb.DataType_None) + suite.vecDataType = schemapb.DataType_BFloat16Vector + suite.run(schemapb.DataType_Int32, schemapb.DataType_None) + suite.vecDataType = schemapb.DataType_SparseFloatVector + suite.run(schemapb.DataType_Int32, schemapb.DataType_None) +} + +func TestUtil(t *testing.T) { + suite.Run(t, new(ReaderSuite)) +} diff --git a/internal/util/importutilv2/csv/row_parser.go b/internal/util/importutilv2/csv/row_parser.go new file mode 100644 index 0000000000000..3d134628cbbb3 --- /dev/null +++ b/internal/util/importutilv2/csv/row_parser.go @@ -0,0 +1,425 @@ +package csv + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/samber/lo" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type RowParser interface { + Parse(raw []string) (Row, error) +} +type rowParser struct { + header []string + name2Dim map[string]int + name2Field map[string]*schemapb.FieldSchema + pkField *schemapb.FieldSchema + dynamicField *schemapb.FieldSchema +} + +func NewRowParser(schema *schemapb.CollectionSchema, header []string) (RowParser, error) { + name2Field := lo.KeyBy(schema.GetFields(), + func(field *schemapb.FieldSchema) string { + return field.GetName() + }) + + name2Dim := make(map[string]int) + for name, field := range name2Field { + if typeutil.IsVectorType(field.GetDataType()) && !typeutil.IsSparseFloatVectorType(field.GetDataType()) { + dim, err := typeutil.GetDim(field) + if err != nil { + return nil, err + } + name2Dim[name] = int(dim) + } + } + + pkField, err := typeutil.GetPrimaryFieldSchema(schema) + if err != nil { + return nil, err + } + + if pkField.GetAutoID() { + delete(name2Field, pkField.GetName()) + } + + dynamicField := typeutil.GetDynamicField(schema) + if dynamicField != nil { + delete(name2Field, dynamicField.GetName()) + } + + // check if csv header provides the primary key while it should be auto-generated + if pkField.GetAutoID() && lo.Contains(header, pkField.GetName()) { + return nil, 
merr.WrapErrImportFailed(
+            fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", pkField.GetName()))
+    }
+
+    // check whether the csv header contains all fields in the schema
+    // (the auto-generated primary key and the dynamic field were already removed from name2Field above)
+    nameMap := make(map[string]bool)
+    for _, name := range header {
+        nameMap[name] = true
+    }
+    for fieldName := range name2Field {
+        if _, ok := nameMap[fieldName]; !ok {
+            return nil, merr.WrapErrImportFailed(fmt.Sprintf("value of field is missed: '%s'", fieldName))
+        }
+    }
+
+    return &rowParser{
+        name2Dim:     name2Dim,
+        header:       header,
+        name2Field:   name2Field,
+        pkField:      pkField,
+        dynamicField: dynamicField,
+    }, nil
+}
+
+func (r *rowParser) Parse(strArr []string) (Row, error) {
+    if len(strArr) != len(r.header) {
+        return nil, merr.WrapErrImportFailed("the number of fields in the row is not equal to the header")
+    }
+
+    row := make(Row)
+    dynamicValues := make(map[string]string)
+    for index, value := range strArr {
+        if field, ok := r.name2Field[r.header[index]]; ok {
+            data, err := r.parseEntity(field, value)
+            if err != nil {
+                return nil, err
+            }
+            row[field.GetFieldID()] = data
+        } else if r.dynamicField != nil {
+            dynamicValues[r.header[index]] = value
+        } else {
+            return nil, merr.WrapErrImportFailed(fmt.Sprintf("the field '%s' is not defined in schema", r.header[index]))
+        }
+    }
+
+    // combine the redundant pairs into the dynamic field
+    // for a csv file uploaded directly to minio, fields not defined in the schema must be checked and put into the dynamic field
+    if r.dynamicField != nil {
+        err := r.combineDynamicRow(dynamicValues, row)
+        if err != nil {
+            return nil, err
+        }
+    }
+    return row, nil
+}
+
+func (r *rowParser) combineDynamicRow(dynamicValues map[string]string, row Row) error {
+    dynamicFieldID := r.dynamicField.GetFieldID()
+    metaName := r.dynamicField.GetName()
+    if len(dynamicValues) == 0 {
+        row[dynamicFieldID] = []byte("{}")
+        return nil
+    }
+
+    newDynamicValues := make(map[string]any)
+    if str, ok := dynamicValues[metaName]; ok {
+        // parse the $meta field into a json object
+        var mp map[string]interface{}
+        err := json.Unmarshal([]byte(str), &mp)
+        if err != nil {
+            return merr.WrapErrImportFailed("illegal value for dynamic field, not a JSON format string")
+        }
+        // put all the dynamic fields into newDynamicValues
+        for k, v := range mp {
+            if _, ok = dynamicValues[k]; ok {
+                return merr.WrapErrImportFailed(fmt.Sprintf("duplicated key in dynamic field, key=%s", k))
+            }
+            newDynamicValues[k] = v
+        }
+        // remove the $meta field from dynamicValues
+        delete(dynamicValues, metaName)
+    }
+    // put the dynamic fields (except $meta) into newDynamicValues
+    // due to the limitation of csv, numeric values are stored as strings
+    for k, v := range dynamicValues {
+        newDynamicValues[k] = v
+    }
+
+    // check that the combined values can be marshaled into a JSON object
+    dynamicBytes, err := json.Marshal(newDynamicValues)
+    if err != nil {
+        return merr.WrapErrImportFailed("illegal value for dynamic field, not a JSON object")
+    }
+    row[dynamicFieldID] = dynamicBytes
+
+    return nil
+}
+
+func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, error) {
+    switch field.GetDataType() {
+    case schemapb.DataType_Bool:
+        b, err := strconv.ParseBool(obj)
+        if err != nil {
+            return false, r.wrapTypeError(obj, field)
+        }
+        return b, nil
+    case schemapb.DataType_Int8:
+        num, err := strconv.ParseInt(obj, 10, 8)
+        if err != nil {
+            return 0, r.wrapTypeError(obj, field)
+        }
+        return 
int8(num), nil + case schemapb.DataType_Int16: + num, err := strconv.ParseInt(obj, 10, 16) + if err != nil { + return 0, r.wrapTypeError(obj, field) + } + return int16(num), nil + case schemapb.DataType_Int32: + num, err := strconv.ParseInt(obj, 10, 32) + if err != nil { + return 0, r.wrapTypeError(obj, field) + } + return int32(num), nil + case schemapb.DataType_Int64: + num, err := strconv.ParseInt(obj, 10, 64) + if err != nil { + return 0, r.wrapTypeError(obj, field) + } + return num, nil + case schemapb.DataType_Float: + num, err := strconv.ParseFloat(obj, 32) + if err != nil { + return 0, r.wrapTypeError(obj, field) + } + return float32(num), nil + case schemapb.DataType_Double: + num, err := strconv.ParseFloat(obj, 64) + if err != nil { + return 0, r.wrapTypeError(obj, field) + } + return num, nil + case schemapb.DataType_VarChar, schemapb.DataType_String: + return obj, nil + case schemapb.DataType_BinaryVector: + var vec []byte + err := json.Unmarshal([]byte(obj), &vec) + if err != nil { + return nil, r.wrapTypeError(obj, field) + } + if len(vec) != r.name2Dim[field.GetName()]/8 { + return nil, r.wrapDimError(len(vec)*8, field) + } + return vec, nil + case schemapb.DataType_JSON: + var data interface{} + err := json.Unmarshal([]byte(obj), &data) + if err != nil { + return nil, err + } + return []byte(obj), nil + case schemapb.DataType_FloatVector: + var vec []float32 + err := json.Unmarshal([]byte(obj), &vec) + if err != nil { + return nil, r.wrapTypeError(obj, field) + } + if len(vec) != r.name2Dim[field.GetName()] { + return nil, r.wrapDimError(len(vec), field) + } + return vec, nil + case schemapb.DataType_Float16Vector: + var vec []float32 + err := json.Unmarshal([]byte(obj), &vec) + if err != nil { + return nil, r.wrapTypeError(obj, field) + } + if len(vec) != r.name2Dim[field.GetName()] { + return nil, r.wrapDimError(len(vec), field) + } + vec2 := make([]byte, len(vec)*2) + for i := 0; i < len(vec); i++ { + copy(vec2[i*2:], typeutil.Float32ToFloat16Bytes(vec[i])) + } + return vec2, nil + case schemapb.DataType_BFloat16Vector: + var vec []float32 + err := json.Unmarshal([]byte(obj), &vec) + if err != nil { + return nil, r.wrapTypeError(obj, field) + } + if len(vec) != r.name2Dim[field.GetName()] { + return nil, r.wrapDimError(len(vec), field) + } + vec2 := make([]byte, len(vec)*2) + for i := 0; i < len(vec); i++ { + copy(vec2[i*2:], typeutil.Float32ToBFloat16Bytes(vec[i])) + } + return vec2, nil + case schemapb.DataType_SparseFloatVector: + // use dec.UseNumber() to avoid float64 precision loss + var vec map[string]interface{} + dec := json.NewDecoder(strings.NewReader(obj)) + dec.UseNumber() + err := dec.Decode(&vec) + if err != nil { + return nil, r.wrapTypeError(obj, field) + } + vec2, err := typeutil.CreateSparseFloatRowFromMap(vec) + if err != nil { + return nil, err + } + return vec2, nil + case schemapb.DataType_Array: + var vec []interface{} + desc := json.NewDecoder(strings.NewReader(obj)) + desc.UseNumber() + err := desc.Decode(&vec) + if err != nil { + return nil, r.wrapTypeError(obj, field) + } + scalarFieldData, err := r.arrayToFieldData(vec, field.GetElementType()) + if err != nil { + return nil, err + } + return scalarFieldData, nil + default: + return nil, merr.WrapErrImportFailed(fmt.Sprintf("parse csv failed, unsupport data type: %s", + field.GetDataType().String())) + } +} + +func (r *rowParser) arrayToFieldData(arr []interface{}, eleType schemapb.DataType) (*schemapb.ScalarField, error) { + switch eleType { + case schemapb.DataType_Bool: + values := 
make([]bool, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(bool) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + values = append(values, value) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + values := make([]int32, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + num, err := strconv.ParseInt(value.String(), 10, 32) + if err != nil { + return nil, err + } + values = append(values, int32(num)) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_Int64: + values := make([]int64, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + num, err := strconv.ParseInt(value.String(), 10, 64) + if err != nil { + return nil, err + } + values = append(values, num) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_Float: + values := make([]float32, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + num, err := strconv.ParseFloat(value.String(), 32) + if err != nil { + return nil, err + } + values = append(values, float32(num)) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_Double: + values := make([]float64, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(json.Number) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + num, err := strconv.ParseFloat(value.String(), 64) + if err != nil { + return nil, err + } + values = append(values, num) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: values, + }, + }, + }, nil + case schemapb.DataType_VarChar, schemapb.DataType_String: + values := make([]string, 0) + for i := 0; i < len(arr); i++ { + value, ok := arr[i].(string) + if !ok { + return nil, r.wrapArrayValueTypeError(arr, eleType) + } + values = append(values, value) + } + return &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: values, + }, + }, + }, nil + default: + return nil, merr.WrapErrImportFailed(fmt.Sprintf("parse csv failed, unsupport data type: %s", eleType.String())) + } +} + +func (r *rowParser) wrapTypeError(v any, field *schemapb.FieldSchema) error { + return merr.WrapErrImportFailed(fmt.Sprintf("expected type '%s' for field '%s', got type '%T' with value '%v'", + field.GetDataType().String(), field.GetName(), v, v)) +} + +func (r *rowParser) wrapDimError(actualDim int, field *schemapb.FieldSchema) error { + return merr.WrapErrImportFailed(fmt.Sprintf("expected dim '%d' for field '%s' with type '%s', got dim '%d'", + r.name2Dim[field.GetName()], field.GetName(), field.GetDataType().String(), actualDim)) +} + +func (r *rowParser) wrapArrayValueTypeError(v any, eleType schemapb.DataType) error { + return merr.WrapErrImportFailed(fmt.Sprintf("expected element type '%s' in array field, got type 
'%T' with value '%v'", + eleType.String(), v, v)) +} diff --git a/internal/util/importutilv2/csv/row_parser_test.go b/internal/util/importutilv2/csv/row_parser_test.go new file mode 100644 index 0000000000000..3c74fc195fc6f --- /dev/null +++ b/internal/util/importutilv2/csv/row_parser_test.go @@ -0,0 +1,173 @@ +package csv + +import ( + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" +) + +func TestNewRowParser_Invalid(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 2, + Name: "vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "2"}}, + }, + { + FieldID: 3, + Name: "str", + DataType: schemapb.DataType_VarChar, + }, + { + FieldID: 4, + Name: "$meta", + IsDynamic: true, + DataType: schemapb.DataType_JSON, + }, + }, + } + + type testCase struct { + header []string + expectErr string + } + + cases := []testCase{ + {header: []string{"id", "vector", "$meta"}, expectErr: "value of field is missed: 'str'"}, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) { + _, err := NewRowParser(schema, c.header) + assert.Error(t, err) + assert.True(t, strings.Contains(err.Error(), c.expectErr)) + }) + } +} + +func TestRowParser_Parse_Valid(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 2, + Name: "vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "2"}}, + }, + { + FieldID: 3, + Name: "$meta", + IsDynamic: true, + DataType: schemapb.DataType_JSON, + }, + }, + } + + type testCase struct { + header []string + row []string + dyFields map[string]any // expect dynamic fields + } + + cases := []testCase{ + {header: []string{"id", "vector", "$meta"}, row: []string{"1", "[1, 2]", "{\"y\": 2}"}, dyFields: map[string]any{"y": 2.0}}, + {header: []string{"id", "vector", "x", "$meta"}, row: []string{"1", "[1, 2]", "8", "{\"y\": 2}"}, dyFields: map[string]any{"x": "8", "y": 2.0}}, + {header: []string{"id", "vector", "x", "$meta"}, row: []string{"1", "[1, 2]", "8", "{}"}, dyFields: map[string]any{"x": "8"}}, + {header: []string{"id", "vector", "x"}, row: []string{"1", "[1, 2]", "8"}, dyFields: map[string]any{"x": "8"}}, + {header: []string{"id", "vector", "str", "$meta"}, row: []string{"1", "[1, 2]", "xxsddsffwq", "{\"y\": 2}"}, dyFields: map[string]any{"y": 2.0, "str": "xxsddsffwq"}}, + } + + for i, c := range cases { + r, err := NewRowParser(schema, c.header) + assert.NoError(t, err) + t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) { + data, err := r.Parse(c.row) + assert.NoError(t, err) + + // validate contains fields + for _, field := range schema.GetFields() { + _, ok := data[field.GetFieldID()] + assert.True(t, ok) + } + + // validate dynamic fields + var dynamicFields map[string]interface{} + err = json.Unmarshal(data[r.(*rowParser).dynamicField.GetFieldID()].([]byte), &dynamicFields) + assert.NoError(t, err) + assert.Len(t, dynamicFields, len(c.dyFields)) + for k, v := range c.dyFields { + rv, ok := 
dynamicFields[k] + assert.True(t, ok) + assert.EqualValues(t, rv, v) + } + }) + } +} + +func TestRowParser_Parse_Invalid(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 2, + Name: "vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "2"}}, + }, + { + FieldID: 3, + Name: "$meta", + IsDynamic: true, + DataType: schemapb.DataType_JSON, + }, + }, + } + + type testCase struct { + header []string + row []string + expectErr string + } + + cases := []testCase{ + {header: []string{"id", "vector", "x", "$meta"}, row: []string{"1", "[1, 2]", "6", "{\"x\": 8}"}, expectErr: "duplicated key in dynamic field, key=x"}, + {header: []string{"id", "vector", "x", "$meta"}, row: []string{"1", "[1, 2]", "8", "{*&%%&$*(&}"}, expectErr: "illegal value for dynamic field, not a JSON format string"}, + {header: []string{"id", "vector", "x", "$meta"}, row: []string{"1", "[1, 2]", "8"}, expectErr: "the number of fields in the row is not equal to the header"}, + } + + for i, c := range cases { + r, err := NewRowParser(schema, c.header) + assert.NoError(t, err) + t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) { + _, err := r.Parse(c.row) + assert.Error(t, err) + assert.True(t, strings.Contains(err.Error(), c.expectErr)) + }) + } +} diff --git a/internal/util/importutilv2/option.go b/internal/util/importutilv2/option.go index b8e958c19fdb2..55fef37794437 100644 --- a/internal/util/importutilv2/option.go +++ b/internal/util/importutilv2/option.go @@ -22,6 +22,8 @@ import ( "strconv" "strings" + "github.com/samber/lo" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -99,3 +101,15 @@ func SkipDiskQuotaCheck(options Options) bool { } return true } + +func GetCSVSep(options Options) (rune, error) { + sep, err := funcutil.GetAttrByKeyFromRepeatedKV("sep", options) + unsupportedSep := []rune{0, '\n', '\r', '"'} + defaultSep := ',' + if err != nil || len(sep) == 0 { + return defaultSep, nil + } else if lo.Contains(unsupportedSep, []rune(sep)[0]) { + return 0, merr.WrapErrImportFailed(fmt.Sprintf("unsupported csv separator: %s", sep)) + } + return []rune(sep)[0], nil +} diff --git a/internal/util/importutilv2/reader.go b/internal/util/importutilv2/reader.go index de142feca1395..443ed3c11d224 100644 --- a/internal/util/importutilv2/reader.go +++ b/internal/util/importutilv2/reader.go @@ -23,6 +23,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/importutilv2/binlog" + "github.com/milvus-io/milvus/internal/util/importutilv2/csv" "github.com/milvus-io/milvus/internal/util/importutilv2/json" "github.com/milvus-io/milvus/internal/util/importutilv2/numpy" "github.com/milvus-io/milvus/internal/util/importutilv2/parquet" @@ -70,6 +71,12 @@ func NewReader(ctx context.Context, return numpy.NewReader(ctx, cm, schema, importFile.GetPaths(), bufferSize) case Parquet: return parquet.NewReader(ctx, cm, schema, importFile.GetPaths()[0], bufferSize) + case CSV: + sep, err := GetCSVSep(options) + if err != nil { + return nil, err + } + return csv.NewReader(ctx, cm, schema, importFile.GetPaths()[0], bufferSize, sep) } return nil, 
merr.WrapErrImportFailed("unexpected import file") } diff --git a/internal/util/importutilv2/util.go b/internal/util/importutilv2/util.go index 0f4f7e2a2fed5..0e4f5539cb6af 100644 --- a/internal/util/importutilv2/util.go +++ b/internal/util/importutilv2/util.go @@ -33,10 +33,12 @@ const ( JSON FileType = 1 Numpy FileType = 2 Parquet FileType = 3 + CSV FileType = 4 JSONFileExt = ".json" NumpyFileExt = ".npy" ParquetFileExt = ".parquet" + CSVFileExt = ".csv" ) var FileTypeName = map[int]string{ @@ -44,6 +46,7 @@ var FileTypeName = map[int]string{ 1: "JSON", 2: "Numpy", 3: "Parquet", + 4: "CSV", } func (f FileType) String() string { @@ -80,6 +83,11 @@ func GetFileType(file *internalpb.ImportFile) (FileType, error) { return Invalid, merr.WrapErrImportFailed("for Parquet import, accepts only one file") } return Parquet, nil + case CSVFileExt: + if len(file.GetPaths()) != 1 { + return Invalid, merr.WrapErrImportFailed("for CSV import, accepts only one file") + } + return CSV, nil } return Invalid, merr.WrapErrImportFailed(fmt.Sprintf("unexpect file type, files=%v", file.GetPaths())) } diff --git a/internal/util/testutil/test_util.go b/internal/util/testutil/test_util.go index a1a19b03ca76b..7a20031522c08 100644 --- a/internal/util/testutil/test_util.go +++ b/internal/util/testutil/test_util.go @@ -570,3 +570,104 @@ func CreateInsertDataRowsForJSON(schema *schemapb.CollectionSchema, insertData * return rows, nil } + +func CreateInsertDataForCSV(schema *schemapb.CollectionSchema, insertData *storage.InsertData) ([][]string, error) { + rowNum := insertData.GetRowNum() + csvData := make([][]string, 0, rowNum+1) + + header := make([]string, 0) + nameToFields := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) string { + name := field.GetName() + if !field.GetAutoID() { + header = append(header, name) + } + return name + }) + csvData = append(csvData, header) + + for i := 0; i < rowNum; i++ { + data := make([]string, 0) + for _, name := range header { + field := nameToFields[name] + value := insertData.Data[field.FieldID] + dataType := field.GetDataType() + elemType := field.GetElementType() + if field.GetAutoID() { + continue + } + switch dataType { + case schemapb.DataType_Array: + var arr any + switch elemType { + case schemapb.DataType_Bool: + arr = value.GetRow(i).(*schemapb.ScalarField).GetBoolData().GetData() + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + arr = value.GetRow(i).(*schemapb.ScalarField).GetIntData().GetData() + case schemapb.DataType_Int64: + arr = value.GetRow(i).(*schemapb.ScalarField).GetLongData().GetData() + case schemapb.DataType_Float: + arr = value.GetRow(i).(*schemapb.ScalarField).GetFloatData().GetData() + case schemapb.DataType_Double: + arr = value.GetRow(i).(*schemapb.ScalarField).GetDoubleData().GetData() + case schemapb.DataType_String: + arr = value.GetRow(i).(*schemapb.ScalarField).GetStringData().GetData() + } + j, err := json.Marshal(arr) + if err != nil { + return nil, err + } + data = append(data, string(j)) + case schemapb.DataType_JSON: + data = append(data, string(value.GetRow(i).([]byte))) + case schemapb.DataType_FloatVector: + vec := value.GetRow(i).([]float32) + j, err := json.Marshal(vec) + if err != nil { + return nil, err + } + data = append(data, string(j)) + case schemapb.DataType_BinaryVector: + bytes := value.GetRow(i).([]byte) + vec := make([]int, 0, len(bytes)) + for _, b := range bytes { + vec = append(vec, int(b)) + } + j, err := json.Marshal(vec) + if err != nil { + return nil, err + } + data 
= append(data, string(j)) + case schemapb.DataType_Float16Vector: + bytes := value.GetRow(i).([]byte) + vec := typeutil.Float16BytesToFloat32Vector(bytes) + j, err := json.Marshal(vec) + if err != nil { + return nil, err + } + data = append(data, string(j)) + case schemapb.DataType_BFloat16Vector: + bytes := value.GetRow(i).([]byte) + vec := typeutil.BFloat16BytesToFloat32Vector(bytes) + j, err := json.Marshal(vec) + if err != nil { + return nil, err + } + data = append(data, string(j)) + case schemapb.DataType_SparseFloatVector: + bytes := value.GetRow(i).([]byte) + m := typeutil.SparseFloatBytesToMap(bytes) + j, err := json.Marshal(m) + if err != nil { + return nil, err + } + data = append(data, string(j)) + default: + str := fmt.Sprintf("%v", value.GetRow(i)) + data = append(data, str) + } + } + csvData = append(csvData, data) + } + + return csvData, nil +} diff --git a/tests/integration/import/dynamic_field_test.go b/tests/integration/import/dynamic_field_test.go index f714176af0060..4c928df891bb9 100644 --- a/tests/integration/import/dynamic_field_test.go +++ b/tests/integration/import/dynamic_field_test.go @@ -99,6 +99,8 @@ func (s *BulkInsertSuite) testImportDynamicField() { err = os.MkdirAll(c.ChunkManager.RootPath(), os.ModePerm) s.NoError(err) + options := []*commonpb.KeyValuePair{} + switch s.fileType { case importutilv2.Numpy: importFile, err := GenerateNumpyFiles(c.ChunkManager, schema, rowCount) @@ -130,11 +132,25 @@ func (s *BulkInsertSuite) testImportDynamicField() { }, }, } + case importutilv2.CSV: + filePath := fmt.Sprintf("/tmp/test_%d.csv", rand.Int()) + sep := GenerateCSVFile(s.T(), filePath, schema, rowCount) + defer os.Remove(filePath) + options = []*commonpb.KeyValuePair{{Key: "sep", Value: string(sep)}} + s.NoError(err) + files = []*internalpb.ImportFile{ + { + Paths: []string{ + filePath, + }, + }, + } } importResp, err := c.Proxy.ImportV2(ctx, &internalpb.ImportRequest{ CollectionName: collectionName, Files: files, + Options: options, }) s.NoError(err) s.Equal(int32(0), importResp.GetStatus().GetCode()) @@ -197,3 +213,8 @@ func (s *BulkInsertSuite) TestImportDynamicField_Parquet() { s.fileType = importutilv2.Parquet s.testImportDynamicField() } + +func (s *BulkInsertSuite) TestImportDynamicField_CSV() { + s.fileType = importutilv2.CSV + s.testImportDynamicField() +} diff --git a/tests/integration/import/import_test.go b/tests/integration/import/import_test.go index 4e8c4f85f3082..0eb43b55f8e3f 100644 --- a/tests/integration/import/import_test.go +++ b/tests/integration/import/import_test.go @@ -107,6 +107,8 @@ func (s *BulkInsertSuite) run() { err = os.MkdirAll(c.ChunkManager.RootPath(), os.ModePerm) s.NoError(err) + options := []*commonpb.KeyValuePair{} + switch s.fileType { case importutilv2.Numpy: importFile, err := GenerateNumpyFiles(c.ChunkManager, schema, rowCount) @@ -135,11 +137,25 @@ func (s *BulkInsertSuite) run() { }, }, } + case importutilv2.CSV: + filePath := fmt.Sprintf("/tmp/test_%d.csv", rand.Int()) + sep := GenerateCSVFile(s.T(), filePath, schema, rowCount) + defer os.Remove(filePath) + options = []*commonpb.KeyValuePair{{Key: "sep", Value: string(sep)}} + s.NoError(err) + files = []*internalpb.ImportFile{ + { + Paths: []string{ + filePath, + }, + }, + } } importResp, err := c.Proxy.ImportV2(ctx, &internalpb.ImportRequest{ CollectionName: collectionName, Files: files, + Options: options, }) s.NoError(err) s.Equal(int32(0), importResp.GetStatus().GetCode()) @@ -203,7 +219,7 @@ func (s *BulkInsertSuite) run() { } func (s *BulkInsertSuite) 
TestMultiFileTypes() { - fileTypeArr := []importutilv2.FileType{importutilv2.JSON, importutilv2.Numpy, importutilv2.Parquet} + fileTypeArr := []importutilv2.FileType{importutilv2.JSON, importutilv2.Numpy, importutilv2.Parquet, importutilv2.CSV} for _, fileType := range fileTypeArr { s.fileType = fileType diff --git a/tests/integration/import/multi_vector_test.go b/tests/integration/import/multi_vector_test.go index 7738853a56bd3..b6920f099ed46 100644 --- a/tests/integration/import/multi_vector_test.go +++ b/tests/integration/import/multi_vector_test.go @@ -123,6 +123,8 @@ func (s *BulkInsertSuite) testMultipleVectorFields() { err = os.MkdirAll(c.ChunkManager.RootPath(), os.ModePerm) s.NoError(err) + options := []*commonpb.KeyValuePair{} + switch s.fileType { case importutilv2.Numpy: importFile, err := GenerateNumpyFiles(c.ChunkManager, schema, rowCount) @@ -154,11 +156,25 @@ func (s *BulkInsertSuite) testMultipleVectorFields() { }, }, } + case importutilv2.CSV: + filePath := fmt.Sprintf("/tmp/test_%d.csv", rand.Int()) + sep := GenerateCSVFile(s.T(), filePath, schema, rowCount) + defer os.Remove(filePath) + options = []*commonpb.KeyValuePair{{Key: "sep", Value: string(sep)}} + s.NoError(err) + files = []*internalpb.ImportFile{ + { + Paths: []string{ + filePath, + }, + }, + } } importResp, err := c.Proxy.ImportV2(ctx, &internalpb.ImportRequest{ CollectionName: collectionName, Files: files, + Options: options, }) s.NoError(err) s.Equal(int32(0), importResp.GetStatus().GetCode()) @@ -226,3 +242,8 @@ func (s *BulkInsertSuite) TestMultipleVectorFields_Parquet() { s.fileType = importutilv2.Parquet s.testMultipleVectorFields() } + +func (s *BulkInsertSuite) TestMultipleVectorFields_CSV() { + s.fileType = importutilv2.CSV + s.testMultipleVectorFields() +} diff --git a/tests/integration/import/util_test.go b/tests/integration/import/util_test.go index c9901ba2ad8bf..58eb84ab56523 100644 --- a/tests/integration/import/util_test.go +++ b/tests/integration/import/util_test.go @@ -18,6 +18,7 @@ package importv2 import ( "context" + "encoding/csv" "encoding/json" "fmt" "os" @@ -202,6 +203,26 @@ func GenerateJSONFile(t *testing.T, filePath string, schema *schemapb.Collection assert.NoError(t, err) } +func GenerateCSVFile(t *testing.T, filePath string, schema *schemapb.CollectionSchema, count int) rune { + insertData, err := testutil.CreateInsertData(schema, count) + assert.NoError(t, err) + + csvData, err := testutil.CreateInsertDataForCSV(schema, insertData) + assert.NoError(t, err) + + sep := ',' + wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) + assert.NoError(t, err) + + writer := csv.NewWriter(wf) + writer.Comma = sep + writer.WriteAll(csvData) + writer.Flush() + assert.NoError(t, err) + + return sep +} + func WaitForImportDone(ctx context.Context, c *integration.MiniClusterV2, jobID string) error { for { resp, err := c.Proxy.GetImportProgress(ctx, &internalpb.GetImportProgressRequest{
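// For illustration only (not part of the patch): with the default ',' separator, a CSV
// file accepted by the new reader mirrors the cases in row_parser_test.go: the first row
// is a header of field names, vector fields are written as JSON arrays, and columns that
// are not in the schema (plus an optional "$meta" JSON column) are merged into the
// dynamic field. Assuming a schema with an Int64 "id" primary key, a dim-2 "vector"
// field, and a dynamic "$meta" field, such a file could look like:
//
//     id,vector,x,$meta
//     1,"[0.1, 0.2]",8,"{""y"": 2}"
//     2,"[0.3, 0.4]",9,{}
//
// Fields containing the separator or quotes must be quoted per encoding/csv rules, and a
// non-default separator can be supplied through the "sep" import option (see GetCSVSep).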