From a39bce40a8bd9deca3730a97799f2999908b737c Mon Sep 17 00:00:00 2001 From: OxalisCu <2127298698@qq.com> Date: Mon, 29 Jul 2024 04:07:53 +0000 Subject: [PATCH] feat: support multiple vectors in csv import --- internal/util/importutilv2/csv/reader.go | 6 +-- internal/util/importutilv2/csv/row_parser.go | 48 ++++++++------------ internal/util/importutilv2/option.go | 12 +---- 3 files changed, 25 insertions(+), 41 deletions(-) diff --git a/internal/util/importutilv2/csv/reader.go b/internal/util/importutilv2/csv/reader.go index ee0cb6faac492..dd5e13e49b784 100644 --- a/internal/util/importutilv2/csv/reader.go +++ b/internal/util/importutilv2/csv/reader.go @@ -5,13 +5,14 @@ import ( "encoding/csv" "fmt" "io" - "strings" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" "go.uber.org/atomic" + "go.uber.org/zap" ) type Row = map[storage.FieldID]any @@ -43,8 +44,7 @@ func NewReader(ctx context.Context, cm storage.ChunkManager, schema *schemapb.Co csvReader.Comma = sep header, err := csvReader.Read() - fmt.Printf("csv header: %v\n", strings.Join(header, "|")) - fmt.Printf("length: %v\n", len(header)) + log.Info("csv header parsed", zap.Strings("header", header)) if err != nil { return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to read csv header, error: %v", err)) } diff --git a/internal/util/importutilv2/csv/row_parser.go b/internal/util/importutilv2/csv/row_parser.go index 231197335754c..5b27c0167fb01 100644 --- a/internal/util/importutilv2/csv/row_parser.go +++ b/internal/util/importutilv2/csv/row_parser.go @@ -2,7 +2,6 @@ package csv import ( "encoding/json" - "errors" "fmt" "strconv" "strings" @@ -17,8 +16,8 @@ type RowParser interface { Parse(raw []string) (Row, error) } type rowParser struct { - dim int header []string + name2Dim map[string]int name2Field map[string]*schemapb.FieldSchema pkField *schemapb.FieldSchema dynamicField *schemapb.FieldSchema @@ -30,17 +29,17 @@ func NewRowParser(schema *schemapb.CollectionSchema, header []string) (RowParser return field.GetName() }) - vecField, err := typeutil.GetVectorFieldSchema(schema) - if err != nil { - return nil, err - } - dim := int64(0) - if typeutil.IsVectorType(vecField.DataType) && !typeutil.IsSparseFloatVectorType(vecField.DataType) { - dim, err = typeutil.GetDim(vecField) - if err != nil { - return nil, err + name2Dim := make(map[string]int) + for name, field := range name2Field { + if typeutil.IsVectorType(field.GetDataType()) && !typeutil.IsSparseFloatVectorType(field.GetDataType()) { + dim, err := typeutil.GetDim(field) + if err != nil { + return nil, err + } + name2Dim[name] = int(dim) } } + pkField, err := typeutil.GetPrimaryFieldSchema(schema) if err != nil { return nil, err @@ -56,14 +55,7 @@ func NewRowParser(schema *schemapb.CollectionSchema, header []string) (RowParser } // check if csv header provides the primary key while it should be auto-generated - containsPk := false - for _, v := range header { - if v == pkField.GetName() { - containsPk = true - break - } - } - if pkField.GetAutoID() && containsPk { + if pkField.GetAutoID() && lo.Contains(header, pkField.GetName()) { return nil, merr.WrapErrImportFailed( fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", pkField.GetName())) } @@ -81,7 +73,7 @@ func NewRowParser(schema *schemapb.CollectionSchema, header []string) (RowParser } return &rowParser{ - dim: int(dim), + name2Dim: name2Dim, header: header, name2Field: name2Field, pkField: pkField, @@ -215,7 +207,7 @@ func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, e if err != nil { return nil, r.wrapTypeError(obj, field) } - if len(vec) != r.dim/8 { + if len(vec) != r.name2Dim[field.GetName()]/8 { return nil, r.wrapDimError(len(vec)*8, field) } for i := 0; i < len(vec); i++ { @@ -235,7 +227,7 @@ func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, e if err != nil { return nil, r.wrapTypeError(obj, field) } - if len(vec) != r.dim { + if len(vec) != r.name2Dim[field.GetName()] { return nil, r.wrapDimError(len(vec), field) } for i := 0; i < len(vec); i++ { @@ -248,7 +240,7 @@ func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, e if err != nil { return nil, r.wrapTypeError(obj, field) } - if len(vec) != r.dim { + if len(vec) != r.name2Dim[field.GetName()] { return nil, r.wrapDimError(len(vec), field) } vec2 := make([]byte, len(vec)*2) @@ -262,7 +254,7 @@ func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, e if err != nil { return nil, r.wrapTypeError(obj, field) } - if len(vec) != r.dim { + if len(vec) != r.name2Dim[field.GetName()] { return nil, r.wrapDimError(len(vec), field) } vec2 := make([]byte, len(vec)*2) @@ -328,7 +320,7 @@ func (r *rowParser) arrayToFieldData(arr []interface{}, eleType schemapb.DataTyp if !ok { return nil, r.wrapArrayValueTypeError(arr, eleType) } - num, err := strconv.ParseInt(value.String(), 0, 32) + num, err := strconv.ParseInt(value.String(), 10, 32) if err != nil { return nil, err } @@ -348,7 +340,7 @@ func (r *rowParser) arrayToFieldData(arr []interface{}, eleType schemapb.DataTyp if !ok { return nil, r.wrapArrayValueTypeError(arr, eleType) } - num, err := strconv.ParseInt(value.String(), 0, 64) + num, err := strconv.ParseInt(value.String(), 10, 64) if err != nil { return nil, err } @@ -418,7 +410,7 @@ func (r *rowParser) arrayToFieldData(arr []interface{}, eleType schemapb.DataTyp }, }, nil default: - return nil, errors.New(fmt.Sprintf("unsupported array data type '%s'", eleType.String())) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("parse csv failed, unsupport data type: %s", eleType.String())) } } @@ -429,7 +421,7 @@ func (r *rowParser) wrapTypeError(v any, field *schemapb.FieldSchema) error { func (r *rowParser) wrapDimError(actualDim int, field *schemapb.FieldSchema) error { return merr.WrapErrImportFailed(fmt.Sprintf("expected dim '%d' for field '%s' with type '%s', got dim '%d'", - r.dim, field.GetName(), field.GetDataType().String(), actualDim)) + r.name2Dim[field.GetName()], field.GetName(), field.GetDataType().String(), actualDim)) } func (r *rowParser) wrapArrayValueTypeError(v any, eleType schemapb.DataType) error { diff --git a/internal/util/importutilv2/option.go b/internal/util/importutilv2/option.go index 60967de08db8d..664c0c9628ac6 100644 --- a/internal/util/importutilv2/option.go +++ b/internal/util/importutilv2/option.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/samber/lo" ) const ( @@ -86,22 +87,13 @@ func IsL0Import(options Options) bool { return true } -func containsSep(r rune, rs []rune) bool { - for _, v := range rs { - if v == r { - return true - } - } - return false -} - func GetCSVSep(options Options) (rune, error) { sep, err := funcutil.GetAttrByKeyFromRepeatedKV("sep", options) unsupportedSep := []rune{0, '\n', '\r', '"'} defaultSep := ',' if err != nil || len(sep) == 0 { return defaultSep, nil - } else if containsSep([]rune(sep)[0], unsupportedSep) { + } else if lo.Contains(unsupportedSep, []rune(sep)[0]) { return 0, merr.WrapErrImportFailed(fmt.Sprintf("unsupported csv seperator: %s", sep)) } return []rune(sep)[0], nil