diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 25eb970eb8893..3c997efd4e3bb 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -193,6 +193,7 @@ proxy: # As of today (2.2.0 and after) it is strongly DISCOURAGED to set maxFieldNum >= 64. # So adjust at your risk! maxFieldNum: 64 + maxVectorFieldNum: 4 # Maximum number of vector fields in a collection, (0, 10]. maxShardNum: 16 # Maximum number of shards in a collection maxDimension: 32768 # Maximum dimension of a vector # Whether to produce gin logs.\n diff --git a/internal/proxy/data_coord_mock_test.go b/internal/proxy/data_coord_mock_test.go index 745c3e07328fe..b0d587cdf5d88 100644 --- a/internal/proxy/data_coord_mock_test.go +++ b/internal/proxy/data_coord_mock_test.go @@ -268,6 +268,9 @@ func (coord *DataCoordMock) DropIndex(ctx context.Context, req *indexpb.DropInde } func (coord *DataCoordMock) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) { + if coord.GetIndexStateFunc != nil { + return coord.GetIndexStateFunc(ctx, req, opts...) + } return &indexpb.GetIndexStateResponse{ Status: merr.Success(), State: commonpb.IndexState_Finished, @@ -291,6 +294,9 @@ func (coord *DataCoordMock) GetIndexInfos(ctx context.Context, req *indexpb.GetI // DescribeIndex describe the index info of the collection. func (coord *DataCoordMock) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { + if coord.DescribeIndexFunc != nil { + return coord.DescribeIndexFunc(ctx, req, opts...) + } return &indexpb.DescribeIndexResponse{ Status: merr.Success(), IndexInfos: nil, diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 8feb1f3e7495a..bbe049fefac61 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -519,9 +519,11 @@ func TestProxy(t *testing.T) { shardsNum := common.DefaultShardsNum int64Field := "int64" floatVecField := "fVec" + binaryVecField := "bVec" dim := 128 rowNum := 3000 - indexName := "_default" + floatIndexName := "float_index" + binaryIndexName := "binary_index" nlist := 10 // nprobe := 10 // topk := 10 @@ -558,6 +560,21 @@ func TestProxy(t *testing.T) { IndexParams: nil, AutoID: false, } + bVec := &schemapb.FieldSchema{ + FieldID: 0, + Name: binaryVecField, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + }, + IndexParams: nil, + AutoID: false, + } return &schemapb.CollectionSchema{ Name: collectionName, Description: "", @@ -565,6 +582,7 @@ func TestProxy(t *testing.T) { Fields: []*schemapb.FieldSchema{ pk, fVec, + bVec, }, } } @@ -585,13 +603,14 @@ func TestProxy(t *testing.T) { constructCollectionInsertRequest := func() *milvuspb.InsertRequest { fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) hashKeys := generateHashKeys(rowNum) return &milvuspb.InsertRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, PartitionName: "", - FieldsData: []*schemapb.FieldData{fVecColumn}, + FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn}, HashKeys: hashKeys, NumRows: uint32(rowNum), } @@ -599,13 +618,14 @@ func TestProxy(t *testing.T) { constructPartitionInsertRequest := func() *milvuspb.InsertRequest { fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) hashKeys := generateHashKeys(rowNum) return &milvuspb.InsertRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{fVecColumn}, + FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn}, HashKeys: hashKeys, NumRows: uint32(rowNum), } @@ -613,44 +633,75 @@ func TestProxy(t *testing.T) { constructCollectionUpsertRequest := func() *milvuspb.UpsertRequest { fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) hashKeys := generateHashKeys(rowNum) return &milvuspb.UpsertRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{fVecColumn}, + FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn}, HashKeys: hashKeys, NumRows: uint32(rowNum), } } - constructCreateIndexRequest := func() *milvuspb.CreateIndexRequest { - return &milvuspb.CreateIndexRequest{ + constructCreateIndexRequest := func(dataType schemapb.DataType) *milvuspb.CreateIndexRequest { + req := &milvuspb.CreateIndexRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, - FieldName: floatVecField, - IndexName: indexName, - ExtraParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - { - Key: common.MetricTypeKey, - Value: metric.L2, - }, - { - Key: common.IndexTypeKey, - Value: "IVF_FLAT", - }, - { - Key: "nlist", - Value: strconv.Itoa(nlist), - }, - }, } + switch dataType { + case schemapb.DataType_FloatVector: + { + req.FieldName = floatVecField + req.IndexName = floatIndexName + req.ExtraParams = []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + { + Key: common.MetricTypeKey, + Value: metric.L2, + }, + { + Key: common.IndexTypeKey, + Value: "IVF_FLAT", + }, + { + Key: "nlist", + Value: strconv.Itoa(nlist), + }, + } + } + case schemapb.DataType_BinaryVector: + { + req.FieldName = binaryVecField + req.IndexName = binaryIndexName + req.ExtraParams = []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + { + Key: common.MetricTypeKey, + Value: metric.JACCARD, + }, + { + Key: common.IndexTypeKey, + Value: "BIN_IVF_FLAT", + }, + { + Key: "nlist", + Value: strconv.Itoa(nlist), + }, + } + } + } + + return req } wg.Add(1) @@ -1098,9 +1149,9 @@ func TestProxy(t *testing.T) { }) wg.Add(1) - t.Run("create index", func(t *testing.T) { + t.Run("create index for floatVec field", func(t *testing.T) { defer wg.Done() - req := constructCreateIndexRequest() + req := constructCreateIndexRequest(schemapb.DataType_FloatVector) resp, err := proxy.CreateIndex(ctx, req) assert.NoError(t, err) @@ -1113,7 +1164,7 @@ func TestProxy(t *testing.T) { req := &milvuspb.AlterIndexRequest{ DbName: dbName, CollectionName: collectionName, - IndexName: indexName, + IndexName: floatIndexName, ExtraParams: []*commonpb.KeyValuePair{ { Key: common.MmapEnabledKey, @@ -1139,14 +1190,14 @@ func TestProxy(t *testing.T) { }) err = merr.CheckRPCCall(resp, err) assert.NoError(t, err) - assert.Equal(t, indexName, resp.IndexDescriptions[0].IndexName) + assert.Equal(t, floatIndexName, resp.IndexDescriptions[0].IndexName) assert.True(t, common.IsMmapEnabled(resp.IndexDescriptions[0].GetParams()...), "params: %+v", resp.IndexDescriptions[0]) // disable mmap then the tests below could continue req := &milvuspb.AlterIndexRequest{ DbName: dbName, CollectionName: collectionName, - IndexName: indexName, + IndexName: floatIndexName, ExtraParams: []*commonpb.KeyValuePair{ { Key: common.MmapEnabledKey, @@ -1170,7 +1221,7 @@ func TestProxy(t *testing.T) { }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - indexName = resp.IndexDescriptions[0].IndexName + assert.Equal(t, floatIndexName, resp.IndexDescriptions[0].IndexName) }) wg.Add(1) @@ -1181,7 +1232,7 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, FieldName: floatVecField, - IndexName: indexName, + IndexName: floatIndexName, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) @@ -1195,12 +1246,44 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, FieldName: floatVecField, - IndexName: indexName, + IndexName: floatIndexName, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) + wg.Add(1) + t.Run("load collection not all vecFields with index", func(t *testing.T) { + defer wg.Done() + { + stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.GetStatus().GetErrorCode()) + assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) + } + + resp, err := proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + }) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + + wg.Add(1) + t.Run("create index for binVec field", func(t *testing.T) { + defer wg.Done() + req := constructCreateIndexRequest(schemapb.DataType_BinaryVector) + + resp, err := proxy.CreateIndex(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) + }) + loaded := true wg.Add(1) t.Run("load collection", func(t *testing.T) { @@ -2198,7 +2281,7 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, FieldName: floatVecField, - IndexName: indexName, + IndexName: floatIndexName, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) @@ -3448,6 +3531,21 @@ func TestProxy(t *testing.T) { IndexParams: nil, AutoID: false, } + bVec := &schemapb.FieldSchema{ + FieldID: 0, + Name: binaryVecField, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + }, + IndexParams: nil, + AutoID: false, + } return &schemapb.CollectionSchema{ Name: collectionName, Description: "", @@ -3455,6 +3553,7 @@ func TestProxy(t *testing.T) { Fields: []*schemapb.FieldSchema{ pk, fVec, + bVec, }, } } @@ -3476,13 +3575,14 @@ func TestProxy(t *testing.T) { constructPartitionReqUpsertRequestValid := func() *milvuspb.UpsertRequest { pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) hashKeys := generateHashKeys(rowNum) return &milvuspb.UpsertRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn}, + FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn}, HashKeys: hashKeys, NumRows: uint32(rowNum), } @@ -3491,13 +3591,14 @@ func TestProxy(t *testing.T) { constructPartitionReqUpsertRequestInvalid := func() *milvuspb.UpsertRequest { pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) hashKeys := generateHashKeys(rowNum) return &milvuspb.UpsertRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, PartitionName: "%$@", - FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn}, + FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn}, HashKeys: hashKeys, NumRows: uint32(rowNum), } @@ -3506,13 +3607,14 @@ func TestProxy(t *testing.T) { constructCollectionUpsertRequestValid := func() *milvuspb.UpsertRequest { pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) hashKeys := generateHashKeys(rowNum) return &milvuspb.UpsertRequest{ Base: nil, DbName: dbName, CollectionName: collectionName, PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn}, + FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn}, HashKeys: hashKeys, NumRows: uint32(rowNum), } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index a1cfed1b8b668..8de4be8ccbb1c 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -226,6 +226,10 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error { return fmt.Errorf("maximum field's number should be limited to %d", Params.ProxyCfg.MaxFieldNum.GetAsInt()) } + if len(typeutil.GetVectorFieldSchemas(t.schema)) > Params.ProxyCfg.MaxVectorFieldNum.GetAsInt() { + return fmt.Errorf("maximum vector field's number should be limited to %d", Params.ProxyCfg.MaxVectorFieldNum.GetAsInt()) + } + // validate collection name if err := validateCollectionName(t.schema.Name); err != nil { return err @@ -1456,19 +1460,24 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) { return err } - hasVecIndex := false + // not support multiple indexes on one field fieldIndexIDs := make(map[int64]int64) for _, index := range indexResponse.IndexInfos { fieldIndexIDs[index.FieldID] = index.IndexID - for _, field := range collSchema.Fields { - if index.FieldID == field.FieldID && (field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_Float16Vector) { - hasVecIndex = true + } + + unindexedVecFields := make([]string, 0) + for _, field := range collSchema.GetFields() { + if isVectorType(field.GetDataType()) { + if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok { + unindexedVecFields = append(unindexedVecFields, field.GetName()) } } } - if !hasVecIndex { - errMsg := fmt.Sprintf("there is no vector index on collection: %s, please create index firstly", t.LoadCollectionRequest.CollectionName) - log.Error(errMsg) + + if len(unindexedVecFields) != 0 { + errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields) + log.Debug(errMsg) return errors.New(errMsg) } request := &querypb.LoadCollectionRequest{ diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index a4f0d075c79dc..9f96439610c47 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -294,14 +294,16 @@ func (t *searchTask) PreExecute(ctx context.Context) error { if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) if err != nil || len(annsField) == 0 { - if enableMultipleVectorFields { - return errors.New(AnnsFieldKey + " not found in search_params") - } - vecFieldSchema, err2 := typeutil.GetVectorFieldSchema(t.schema) - if err2 != nil { + vecFields := typeutil.GetVectorFieldSchemas(t.schema) + if len(vecFields) == 0 { return errors.New(AnnsFieldKey + " not found in schema") } - annsField = vecFieldSchema.Name + + if enableMultipleVectorFields && len(vecFields) > 1 { + return errors.New("multiple anns_fields exist, please specify a anns_field in search_params") + } + + annsField = vecFields[0].Name } queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams()) if err != nil { diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 877942f58c177..8427c0c68cc7f 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1700,6 +1700,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { }, CollectionName: collectionName, Nq: 2, + DslType: commonpb.DslType_BoolExprV1, }, qc: qc, lb: lb, @@ -1711,7 +1712,13 @@ func TestSearchTask_ErrExecute(t *testing.T) { assert.NoError(t, task.OnEnqueue()) task.ctx = ctx - assert.NoError(t, task.PreExecute(ctx)) + if enableMultipleVectorFields { + err = task.PreExecute(ctx) + assert.Error(t, err) + assert.Equal(t, err.Error(), "multiple anns_fields exist, please specify a anns_field in search_params") + } else { + assert.NoError(t, task.PreExecute(ctx)) + } qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) assert.Error(t, task.Execute(ctx)) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 445d8bd6c5a63..391bfce7491ba 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -682,10 +682,36 @@ func TestCreateCollectionTask(t *testing.T) { err = task.PreExecute(ctx) assert.Error(t, err) + // too many vector fields + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + schema.Fields = append(schema.Fields, schema.Fields[0]) + for i := 0; i < Params.ProxyCfg.MaxVectorFieldNum.GetAsInt(); i++ { + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 101, + Name: floatVecField + "_" + strconv.Itoa(i), + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(testVecDim), + }, + }, + IndexParams: nil, + AutoID: false, + }) + } + tooManyVectorFieldsSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = tooManyVectorFieldsSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + task.CreateCollectionRequest = reqBackup // validateCollectionName - + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) schema.Name = " " // empty emptyNameSchema, err := proto.Marshal(schema) assert.NoError(t, err) @@ -2551,6 +2577,42 @@ func Test_loadCollectionTask_Execute(t *testing.T) { err := lct.Execute(ctx) assert.Error(t, err) }) + + t.Run("not all vector fields with index", func(t *testing.T) { + vecFields := make([]*schemapb.FieldSchema, 0) + for _, field := range newTestSchema().GetFields() { + if isVectorType(field.GetDataType()) { + vecFields = append(vecFields, field) + } + } + + assert.GreaterOrEqual(t, len(vecFields), 2) + + dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { + return &indexpb.DescribeIndexResponse{ + Status: merr.Success(), + IndexInfos: []*indexpb.IndexInfo{ + { + CollectionID: collectionID, + FieldID: vecFields[0].FieldID, + IndexName: indexName, + IndexID: indexID, + TypeParams: nil, + IndexParams: nil, + IndexedRows: 1025, + TotalRows: 1025, + State: commonpb.IndexState_Finished, + IndexStateFailReason: "", + IsAutoIndex: false, + UserIndexParams: nil, + }, + }, + }, nil + } + + err := lct.Execute(ctx) + assert.Error(t, err) + }) } func Test_loadPartitionTask_Execute(t *testing.T) { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index e5fc96ba59633..e65117dad4025 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -56,7 +56,7 @@ const ( boundedTS = 2 // enableMultipleVectorFields indicates whether to enable multiple vector fields. - enableMultipleVectorFields = false + enableMultipleVectorFields = true defaultMaxVarCharLength = 65535 diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index ccdf7c26576b4..d5738e60f6534 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -952,6 +952,7 @@ type proxyConfig struct { MinPasswordLength ParamItem `refreshable:"true"` MaxPasswordLength ParamItem `refreshable:"true"` MaxFieldNum ParamItem `refreshable:"true"` + MaxVectorFieldNum ParamItem `refreshable:"true"` MaxShardNum ParamItem `refreshable:"true"` MaxDimension ParamItem `refreshable:"true"` GinLogging ParamItem `refreshable:"false"` @@ -1047,6 +1048,22 @@ So adjust at your risk!`, } p.MaxFieldNum.Init(base.mgr) + p.MaxVectorFieldNum = ParamItem{ + Key: "proxy.maxVectorFieldNum", + Version: "2.4.0", + DefaultValue: "4", + Formatter: func(v string) string { + if getAsInt(v) > 10 { + return "10" + } + return v + }, + PanicIfEmpty: true, + Doc: "Maximum number of vector fields in a collection.", + Export: true, + } + p.MaxVectorFieldNum.Init(base.mgr) + p.MaxShardNum = ParamItem{ Key: "proxy.maxShardNum", DefaultValue: "16", diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index 36efea081736f..fabede94d33d6 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -139,6 +139,8 @@ func TestComponentParam(t *testing.T) { t.Logf("MaxFieldNum: %d", Params.MaxFieldNum.GetAsInt64()) + t.Logf("MaxVectorFieldNum: %d", Params.MaxVectorFieldNum.GetAsInt64()) + t.Logf("MaxShardNum: %d", Params.MaxShardNum.GetAsInt64()) t.Logf("MaxDimension: %d", Params.MaxDimension.GetAsInt64()) diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 56dd6600f34e9..ad1d5df3b55fb 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -782,6 +782,18 @@ func GetVectorFieldSchema(schema *schemapb.CollectionSchema) (*schemapb.FieldSch return nil, errors.New("vector field is not found") } +// GetVectorFieldSchemas get vector fields schema from collection schema. +func GetVectorFieldSchemas(schema *schemapb.CollectionSchema) []*schemapb.FieldSchema { + ret := make([]*schemapb.FieldSchema, 0) + for _, fieldSchema := range schema.Fields { + if IsVectorType(fieldSchema.DataType) { + ret = append(ret, fieldSchema) + } + } + + return ret +} + // GetPrimaryFieldSchema get primary field schema from collection schema func GetPrimaryFieldSchema(schema *schemapb.CollectionSchema) (*schemapb.FieldSchema, error) { for _, fieldSchema := range schema.Fields { diff --git a/tests/python_client/testcases/test_collection.py b/tests/python_client/testcases/test_collection.py index fe21b9525ccff..29086cc1f6c80 100644 --- a/tests/python_client/testcases/test_collection.py +++ b/tests/python_client/testcases/test_collection.py @@ -456,7 +456,7 @@ def test_collection_multi_float_vectors(self): """ target: test collection with multi float vectors method: create collection with two float-vec fields - expected: raise exception (not supported yet) + expected: Collection created successfully """ # 1. connect self._connect() @@ -465,25 +465,24 @@ def test_collection_multi_float_vectors(self): fields = [cf.gen_int64_field(is_primary=True), cf.gen_float_field(), cf.gen_float_vec_field(dim=default_dim), cf.gen_float_vec_field(name="tmp", dim=default_dim)] schema = cf.gen_collection_schema(fields=fields) - err_msg = "multiple vector fields is not supported" self.collection_wrap.init_collection(c_name, schema=schema, - check_task=CheckTasks.err_res, - check_items={"err_code": 1, "err_msg": err_msg}) + check_task=CheckTasks.check_collection_property, + check_items={exp_name: c_name, exp_schema: schema}) @pytest.mark.tags(CaseLabel.L1) def test_collection_mix_vectors(self): """ target: test collection with mix vectors method: create with float and binary vec - expected: raise exception + expected: Collection created successfully """ self._connect() c_name = cf.gen_unique_str(prefix) fields = [cf.gen_int64_field(is_primary=True), cf.gen_float_vec_field(), cf.gen_binary_vec_field()] schema = cf.gen_collection_schema(fields=fields, auto_id=True) - err_msg = "multiple vector fields is not supported" - self.collection_wrap.init_collection(c_name, schema=schema, check_task=CheckTasks.err_res, - check_items={"err_code": 1, "err_msg": err_msg}) + self.collection_wrap.init_collection(c_name, schema=schema, + check_task=CheckTasks.check_collection_property, + check_items={exp_name: c_name, exp_schema: schema}) @pytest.mark.tags(CaseLabel.L0) def test_collection_without_vectors(self): diff --git a/tests/python_client/testcases/test_index.py b/tests/python_client/testcases/test_index.py index 2723beb3b6f3a..c951c2300a705 100644 --- a/tests/python_client/testcases/test_index.py +++ b/tests/python_client/testcases/test_index.py @@ -244,8 +244,8 @@ def test_index_create_on_scalar_field(self): collection_w.create_index(ct.default_int64_field_name, {}) collection_w.load(check_task=CheckTasks.err_res, check_items={ct.err_code: 65535, - ct.err_msg: f"there is no vector index on collection: {collection_w.name}, " - f"please create index firstly"}) + ct.err_msg: f"there is no vector index on field: [float_vector], " + f"please create index firstly: collection={collection_w.name}: index not found"}) @pytest.mark.tags(CaseLabel.L2) def test_index_create_on_array_field(self):