Skip to content

Commit

Permalink
feat: support multiple vectors in csv import
Browse files Browse the repository at this point in the history
  • Loading branch information
OxalisCu committed Jul 29, 2024
1 parent 580428e commit a39bce4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 41 deletions.
6 changes: 3 additions & 3 deletions internal/util/importutilv2/csv/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ import (
"encoding/csv"
"fmt"
"io"
"strings"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
"go.uber.org/atomic"
"go.uber.org/zap"
)

type Row = map[storage.FieldID]any
Expand Down Expand Up @@ -43,8 +44,7 @@ func NewReader(ctx context.Context, cm storage.ChunkManager, schema *schemapb.Co
csvReader.Comma = sep

header, err := csvReader.Read()
fmt.Printf("csv header: %v\n", strings.Join(header, "|"))
fmt.Printf("length: %v\n", len(header))
log.Info("csv header parsed", zap.Strings("header", header))
if err != nil {
return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to read csv header, error: %v", err))
}
Expand Down
48 changes: 20 additions & 28 deletions internal/util/importutilv2/csv/row_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package csv

import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
Expand All @@ -17,8 +16,8 @@ type RowParser interface {
Parse(raw []string) (Row, error)
}
type rowParser struct {
dim int
header []string
name2Dim map[string]int
name2Field map[string]*schemapb.FieldSchema
pkField *schemapb.FieldSchema
dynamicField *schemapb.FieldSchema
Expand All @@ -30,17 +29,17 @@ func NewRowParser(schema *schemapb.CollectionSchema, header []string) (RowParser
return field.GetName()
})

vecField, err := typeutil.GetVectorFieldSchema(schema)
if err != nil {
return nil, err
}
dim := int64(0)
if typeutil.IsVectorType(vecField.DataType) && !typeutil.IsSparseFloatVectorType(vecField.DataType) {
dim, err = typeutil.GetDim(vecField)
if err != nil {
return nil, err
name2Dim := make(map[string]int)
for name, field := range name2Field {
if typeutil.IsVectorType(field.GetDataType()) && !typeutil.IsSparseFloatVectorType(field.GetDataType()) {
dim, err := typeutil.GetDim(field)
if err != nil {
return nil, err
}
name2Dim[name] = int(dim)
}
}

pkField, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
return nil, err
Expand All @@ -56,14 +55,7 @@ func NewRowParser(schema *schemapb.CollectionSchema, header []string) (RowParser
}

// check if csv header provides the primary key while it should be auto-generated
containsPk := false
for _, v := range header {
if v == pkField.GetName() {
containsPk = true
break
}
}
if pkField.GetAutoID() && containsPk {
if pkField.GetAutoID() && lo.Contains(header, pkField.GetName()) {
return nil, merr.WrapErrImportFailed(
fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", pkField.GetName()))
}
Expand All @@ -81,7 +73,7 @@ func NewRowParser(schema *schemapb.CollectionSchema, header []string) (RowParser
}

return &rowParser{
dim: int(dim),
name2Dim: name2Dim,
header: header,
name2Field: name2Field,
pkField: pkField,
Expand Down Expand Up @@ -215,7 +207,7 @@ func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, e
if err != nil {
return nil, r.wrapTypeError(obj, field)
}
if len(vec) != r.dim/8 {
if len(vec) != r.name2Dim[field.GetName()]/8 {
return nil, r.wrapDimError(len(vec)*8, field)
}
for i := 0; i < len(vec); i++ {
Expand All @@ -235,7 +227,7 @@ func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, e
if err != nil {
return nil, r.wrapTypeError(obj, field)
}
if len(vec) != r.dim {
if len(vec) != r.name2Dim[field.GetName()] {
return nil, r.wrapDimError(len(vec), field)
}
for i := 0; i < len(vec); i++ {
Expand All @@ -248,7 +240,7 @@ func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, e
if err != nil {
return nil, r.wrapTypeError(obj, field)
}
if len(vec) != r.dim {
if len(vec) != r.name2Dim[field.GetName()] {
return nil, r.wrapDimError(len(vec), field)
}
vec2 := make([]byte, len(vec)*2)
Expand All @@ -262,7 +254,7 @@ func (r *rowParser) parseEntity(field *schemapb.FieldSchema, obj string) (any, e
if err != nil {
return nil, r.wrapTypeError(obj, field)
}
if len(vec) != r.dim {
if len(vec) != r.name2Dim[field.GetName()] {
return nil, r.wrapDimError(len(vec), field)
}
vec2 := make([]byte, len(vec)*2)
Expand Down Expand Up @@ -328,7 +320,7 @@ func (r *rowParser) arrayToFieldData(arr []interface{}, eleType schemapb.DataTyp
if !ok {
return nil, r.wrapArrayValueTypeError(arr, eleType)
}
num, err := strconv.ParseInt(value.String(), 0, 32)
num, err := strconv.ParseInt(value.String(), 10, 32)
if err != nil {
return nil, err
}
Expand All @@ -348,7 +340,7 @@ func (r *rowParser) arrayToFieldData(arr []interface{}, eleType schemapb.DataTyp
if !ok {
return nil, r.wrapArrayValueTypeError(arr, eleType)
}
num, err := strconv.ParseInt(value.String(), 0, 64)
num, err := strconv.ParseInt(value.String(), 10, 64)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -418,7 +410,7 @@ func (r *rowParser) arrayToFieldData(arr []interface{}, eleType schemapb.DataTyp
},
}, nil
default:
return nil, errors.New(fmt.Sprintf("unsupported array data type '%s'", eleType.String()))
return nil, merr.WrapErrImportFailed(fmt.Sprintf("parse csv failed, unsupport data type: %s", eleType.String()))
}
}

Expand All @@ -429,7 +421,7 @@ func (r *rowParser) wrapTypeError(v any, field *schemapb.FieldSchema) error {

func (r *rowParser) wrapDimError(actualDim int, field *schemapb.FieldSchema) error {
return merr.WrapErrImportFailed(fmt.Sprintf("expected dim '%d' for field '%s' with type '%s', got dim '%d'",
r.dim, field.GetName(), field.GetDataType().String(), actualDim))
r.name2Dim[field.GetName()], field.GetName(), field.GetDataType().String(), actualDim))
}

func (r *rowParser) wrapArrayValueTypeError(v any, eleType schemapb.DataType) error {
Expand Down
12 changes: 2 additions & 10 deletions internal/util/importutilv2/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/samber/lo"
)

const (
Expand Down Expand Up @@ -86,22 +87,13 @@ func IsL0Import(options Options) bool {
return true
}

func containsSep(r rune, rs []rune) bool {
for _, v := range rs {
if v == r {
return true
}
}
return false
}

func GetCSVSep(options Options) (rune, error) {
sep, err := funcutil.GetAttrByKeyFromRepeatedKV("sep", options)
unsupportedSep := []rune{0, '\n', '\r', '"'}
defaultSep := ','
if err != nil || len(sep) == 0 {
return defaultSep, nil
} else if containsSep([]rune(sep)[0], unsupportedSep) {
} else if lo.Contains(unsupportedSep, []rune(sep)[0]) {
return 0, merr.WrapErrImportFailed(fmt.Sprintf("unsupported csv seperator: %s", sep))
}
return []rune(sep)[0], nil
Expand Down

0 comments on commit a39bce4

Please sign in to comment.