Skip to content

Commit

Permalink
enhance: Remove duplicated schema helper creation in proxy
Browse files: browse the repository at this point in the history
Related to the PRs for milvus-io#35415.

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia committed Aug 15, 2024
1 parent c6ae7d4 commit d3693db
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 91 deletions.
20 changes: 3 additions & 17 deletions internal/proxy/task_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,7 @@ func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.Fiel
log.Error("failed to get collection schema", zap.Error(err))
return nil, fmt.Errorf("failed to get collection schema: %s", err)
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema)
if err != nil {
log.Error("failed to parse collection schema", zap.Error(err))
return nil, fmt.Errorf("failed to parse collection schema: %s", err)
}
field, err := schemaHelper.GetFieldFromName(cit.req.GetFieldName())
field, err := schema.schemaHelper.GetFieldFromName(cit.req.GetFieldName())
if err != nil {
log.Error("create index on non-exist field", zap.Error(err))
return nil, fmt.Errorf("cannot create index on non-exist field: %s", cit.req.GetFieldName())
Expand Down Expand Up @@ -678,11 +673,6 @@ func (dit *describeIndexTask) Execute(ctx context.Context) error {
log.Error("failed to get collection schema", zap.Error(err))
return fmt.Errorf("failed to get collection schema: %s", err)
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema)
if err != nil {
log.Error("failed to parse collection schema", zap.Error(err))
return fmt.Errorf("failed to parse collection schema: %s", err)
}

resp, err := dit.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{CollectionID: dit.collectionID, IndexName: dit.IndexName, Timestamp: dit.Timestamp})
if err != nil {
Expand All @@ -700,7 +690,7 @@ func (dit *describeIndexTask) Execute(ctx context.Context) error {
return err
}
for _, indexInfo := range resp.IndexInfos {
field, err := schemaHelper.GetFieldFromID(indexInfo.FieldID)
field, err := schema.schemaHelper.GetFieldFromID(indexInfo.FieldID)
if err != nil {
log.Error("failed to get collection field", zap.Error(err))
return fmt.Errorf("failed to get collection field: %d", indexInfo.FieldID)
Expand Down Expand Up @@ -802,11 +792,7 @@ func (dit *getIndexStatisticsTask) Execute(ctx context.Context) error {
log.Error("failed to get collection schema", zap.String("collection_name", dit.GetCollectionName()), zap.Error(err))
return fmt.Errorf("failed to get collection schema: %s", dit.GetCollectionName())
}
schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema)
if err != nil {
log.Error("failed to parse collection schema", zap.String("collection_name", schema.GetName()), zap.Error(err))
return fmt.Errorf("failed to parse collection schema: %s", dit.GetCollectionName())
}
schemaHelper := schema.schemaHelper

resp, err := dit.datacoord.GetIndexStatistics(ctx, &indexpb.GetIndexStatisticsRequest{
CollectionID: dit.collectionID, IndexName: dit.IndexName,
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/task_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}

if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck(), withMaxCapCheck()).
Validate(it.insertMsg.GetFieldsData(), schema.CollectionSchema, it.insertMsg.NRows()); err != nil {
Validate(it.insertMsg.GetFieldsData(), schema.schemaHelper, it.insertMsg.NRows()); err != nil {
return merr.WrapErrAsInputError(err)
}

Expand Down
73 changes: 35 additions & 38 deletions internal/proxy/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,35 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
},
}

idField := &schemapb.FieldSchema{
FieldID: 100,
Name: "id",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: nil,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
AutoID: false,
}
vectorField := &schemapb.FieldSchema{
FieldID: 101,
Name: fieldName,
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: nil,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
AutoID: false,
}

t.Run("normal", func(t *testing.T) {
cache := NewMockCache(t)
cache.On("GetCollectionSchema",
Expand All @@ -2096,20 +2125,8 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
mock.AnythingOfType("string"),
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: fieldName,
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: nil,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
AutoID: false,
},
idField,
vectorField,
},
}), nil)

Expand All @@ -2131,38 +2148,18 @@ func Test_createIndexTask_getIndexedField(t *testing.T) {
assert.Error(t, err)
})

t.Run("invalid schema", func(t *testing.T) {
cache := NewMockCache(t)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: fieldName,
},
{
Name: fieldName, // duplicate
},
},
}), nil)
globalMetaCache = cache
_, err := cit.getIndexedField(context.Background())
assert.Error(t, err)
})

t.Run("field not found", func(t *testing.T) {
otherField := typeutil.Clone(vectorField)
otherField.Name = otherField.Name + "_other"
cache := NewMockCache(t)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
Name: fieldName + fieldName,
},
idField,
otherField,
},
}), nil)
globalMetaCache = cache
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/task_upsert.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
}

if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema, it.upsertMsg.InsertMsg.NRows()); err != nil {
Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.schemaHelper, it.upsertMsg.InsertMsg.NRows()); err != nil {
return err
}

Expand Down
6 changes: 1 addition & 5 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1012,11 +1012,7 @@ func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary
userOutputFieldsMap[outputFieldName] = true
} else {
if schema.EnableDynamicField {
schemaH, err := typeutil.CreateSchemaHelper(schema.CollectionSchema)
if err != nil {
return nil, nil, err
}
err = planparserv2.ParseIdentifier(schemaH, outputFieldName, func(expr *planpb.Expr) error {
err := planparserv2.ParseIdentifier(schema.schemaHelper, outputFieldName, func(expr *planpb.Expr) error {
if len(expr.GetColumnExpr().GetInfo().GetNestedPath()) == 1 &&
expr.GetColumnExpr().GetInfo().GetNestedPath()[0] == outputFieldName {
return nil
Expand Down
10 changes: 4 additions & 6 deletions internal/proxy/validate_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,10 @@ func (v *validateUtil) apply(opts ...validateOption) {
}
}

func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.CollectionSchema, numRows uint64) error {
helper, err := typeutil.CreateSchemaHelper(schema)
if err != nil {
return err
func (v *validateUtil) Validate(data []*schemapb.FieldData, helper *typeutil.SchemaHelper, numRows uint64) error {
if helper == nil {
return merr.WrapErrServiceInternal("nil schema helper provided for Validation")
}

for _, field := range data {
fieldSchema, err := helper.GetFieldFromName(field.GetFieldName())
if err != nil {
Expand Down Expand Up @@ -122,7 +120,7 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col
}
}

err = v.fillWithValue(data, helper, int(numRows))
err := v.fillWithValue(data, helper, int(numRows))
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit d3693db

Please sign in to comment.