From b959cced3507ff0092da92303a3025194d113f61 Mon Sep 17 00:00:00 2001
From: Congqi Xia
Date: Sun, 11 Aug 2024 21:54:32 +0800
Subject: [PATCH] feat: Support horizontal partial load collection

Related to #35415

Signed-off-by: Congqi Xia
---
 internal/proto/query_coord.proto              |   5 +
 internal/proxy/impl.go                        |   2 +-
 internal/proxy/meta_cache.go                  | 100 +++++++++++++++++-
 internal/proxy/task.go                        |  15 +++
 internal/proxy/util.go                        |  17 +--
 internal/querycoordv2/job/job_load.go         |   2 +
 internal/querycoordv2/job/job_test.go         |   6 ++
 .../querycoordv2/meta/collection_manager.go   |  11 ++
 .../observers/collection_observer.go          |  23 ++++
 .../observers/collection_observer_test.go     |  18 ++--
 internal/querycoordv2/ops_service_test.go     |   5 +
 internal/querycoordv2/server.go               |   1 +
 internal/querycoordv2/server_test.go          |   1 +
 internal/querycoordv2/services.go             |   5 +
 internal/querycoordv2/services_test.go        |   7 ++
 internal/querycoordv2/task/executor.go        |   4 +
 internal/querycoordv2/task/utils.go           |   3 +-
 internal/querynodev2/segments/collection.go   |  20 +++-
 .../querynodev2/segments/segment_loader.go    |  10 +-
 pkg/common/common.go                          |  27 +++++
 pkg/common/common_test.go                     |  28 +++++
 pkg/util/typeutil/schema.go                   |  26 ++++-
 22 files changed, 312 insertions(+), 24 deletions(-)

diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto
index ed0e152bbd708..75cfe785992fc 100644
--- a/internal/proto/query_coord.proto
+++ b/internal/proto/query_coord.proto
@@ -187,6 +187,7 @@ message ShowCollectionsResponse {
   repeated int64 inMemory_percentages = 3;
   repeated bool query_service_available = 4;
   repeated int64 refresh_progress = 5;
+  repeated schema.LongArray load_fields = 6;
 }
 
 message ShowPartitionsRequest {
@@ -214,6 +215,7 @@ message LoadCollectionRequest {
   bool refresh = 7;
   // resource group names
   repeated string resource_groups = 8;
+  repeated int64 load_fields = 9;
 }
 
 message ReleaseCollectionRequest {
@@ -244,6 +246,7 @@ message LoadPartitionsRequest {
   // resource group names
   repeated string resource_groups = 9;
   repeated index.IndexInfo index_info_list = 10;
+  repeated int64 load_fields = 11;
 }
 
 message ReleasePartitionsRequest {
@@ -313,6 +316,7 @@ message LoadMetaInfo {
   string metric_type = 4 [deprecated = true];
   string db_name = 5;        // Only used for metrics label.
   string resource_group = 6; // Only used for metrics label.
+  repeated int64 load_fields = 7;
 }
 
 message WatchDmChannelsRequest {
@@ -650,6 +654,7 @@ message CollectionLoadInfo {
   map<int64, int64> field_indexID = 5;
   LoadType load_type = 6;
   int32 recover_times = 7;
+  repeated int64 load_fields = 8;
 }
 
 message PartitionLoadInfo {
diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go
index 1fb0785b40630..20431e5741b05 100644
--- a/internal/proxy/impl.go
+++ b/internal/proxy/impl.go
@@ -123,7 +123,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
 	if globalMetaCache != nil {
 		switch msgType {
-		case commonpb.MsgType_DropCollection, commonpb.MsgType_RenameCollection, commonpb.MsgType_DropAlias, commonpb.MsgType_AlterAlias:
+		case commonpb.MsgType_DropCollection, commonpb.MsgType_RenameCollection, commonpb.MsgType_DropAlias, commonpb.MsgType_AlterAlias, commonpb.MsgType_LoadCollection:
 			if collectionName != "" {
 				globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName) // no need to return error, though collection may be not cached
 				globalMetaCache.DeprecateShardCache(request.GetDbName(), collectionName)
diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go
index e74940c4d7e9e..576292b68ed3d 100644
--- a/internal/proxy/meta_cache.go
+++ b/internal/proxy/meta_cache.go
@@ -128,7 +128,7 @@ type schemaInfo struct {
 	schemaHelper         *typeutil.SchemaHelper
 }
 
-func newSchemaInfo(schema *schemapb.CollectionSchema) *schemaInfo {
+func newSchemaInfoWithLoadFields(schema *schemapb.CollectionSchema, loadFields []int64) *schemaInfo {
 	fieldMap := typeutil.NewConcurrentMap[string, int64]()
 	hasPartitionkey := false
 	var pkField *schemapb.FieldSchema
@@ -142,16 +142,20 @@ func newSchemaInfo(schema *schemapb.CollectionSchema) *schemaInfo {
 		}
 	}
 	// schema shall be verified before
-	schemaHelper, _ := typeutil.CreateSchemaHelper(schema)
+	schemaHelper, _ := typeutil.CreateSchemaHelperWithLoadFields(schema, loadFields)
 	return &schemaInfo{
 		CollectionSchema:     schema,
 		fieldMap:             fieldMap,
 		hasPartitionKeyField: hasPartitionkey,
 		pkField:              pkField,
 		schemaHelper:         schemaHelper,
 	}
 }
 
+func newSchemaInfo(schema *schemapb.CollectionSchema) *schemaInfo {
+	return newSchemaInfoWithLoadFields(schema, nil)
+}
+
 func (s *schemaInfo) MapFieldID(name string) (int64, bool) {
 	return s.fieldMap.Get(name)
 }
@@ -167,6 +172,68 @@ func (s *schemaInfo) GetPkField() (*schemapb.FieldSchema, error) {
 	return s.pkField, nil
 }
 
+// GetLoadFieldIDs returns the field IDs for the requested load field list.
+// If the input `loadFields` is empty, the list is derived from the collection schema definition.
+// Otherwise, the load field list is validated against the schema constraints before the matched field IDs are returned.
+func (s *schemaInfo) GetLoadFieldIDs(loadFields []string) ([]int64, error) {
+	if len(loadFields) == 0 {
+		// skip the constraint check: create collection has already enforced these rules
+		return common.GetCollectionLoadFields(s.CollectionSchema), nil
+	}
+
+	fieldIDs := make([]int64, 0, len(loadFields))
+	fields := make([]*schemapb.FieldSchema, 0, len(loadFields))
+	for _, name := range loadFields {
+		fieldSchema, err := s.schemaHelper.GetFieldFromName(name)
+		if err != nil {
+			return nil, err
+		}
+
+		fields = append(fields, fieldSchema)
+		fieldIDs = append(fieldIDs, fieldSchema.GetFieldID())
+	}
+
+	// validate the load field list
+	if err := s.validateLoadFields(loadFields, fields); err != nil {
+		return nil, err
+	}
+
+	return fieldIDs, nil
+}
+
+func (s *schemaInfo) validateLoadFields(names []string, fields []*schemapb.FieldSchema) error {
+	// ignore the error if the collection has no partition key field
+	partitionKeyField, _ := s.schemaHelper.GetPartitionKeyField()
+
+	var hasPrimaryKey, hasPartitionKey, hasVector bool
+	for _, field := range fields {
+		if field.GetFieldID() == s.pkField.GetFieldID() {
+			hasPrimaryKey = true
+		}
+		if typeutil.IsVectorType(field.GetDataType()) {
+			hasVector = true
+		}
+		if field.IsPartitionKey {
+			hasPartitionKey = true
+		}
+	}
+
+	if !hasPrimaryKey {
+		return merr.WrapErrParameterInvalidMsg("load field list %v does not contain primary key field %s", names, s.pkField.GetName())
+	}
+	if !hasVector {
+		return merr.WrapErrParameterInvalidMsg("load field list %v does not contain vector field", names)
+	}
+	if partitionKeyField != nil && !hasPartitionKey {
+		return merr.WrapErrParameterInvalidMsg("load field list %v does not contain partition key field %s", names, partitionKeyField.GetName())
+	}
+	return nil
+}
+
+func (s *schemaInfo) IsFieldLoaded(fieldID int64) bool {
+	return s.schemaHelper.IsFieldLoaded(fieldID)
+}
+
 // partitionInfos contains the cached collection partition information.
 type partitionInfos struct {
 	partitionInfos []*partitionInfo
@@ -366,6 +433,11 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
 		return nil, err
 	}
 
+	loadFields, err := m.getCollectionLoadFields(ctx, collection.CollectionID)
+	if err != nil {
+		return nil, err
+	}
+
 	// check that partitionID, createdTimestamp and utcstamp have the same element numbers
 	if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) {
 		return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
@@ -393,7 +465,7 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
 		return nil, err
 	}
 
-	schemaInfo := newSchemaInfo(collection.Schema)
+	schemaInfo := newSchemaInfoWithLoadFields(collection.Schema, loadFields)
 	m.collInfo[database][collectionName] = &collectionInfo{
 		collID: collection.CollectionID,
 		schema: schemaInfo,
@@ -760,6 +832,28 @@ func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectio
 	return partitions, nil
 }
 
+func (m *MetaCache) getCollectionLoadFields(ctx context.Context, collectionID UniqueID) ([]int64, error) {
+	req := &querypb.ShowCollectionsRequest{
+		Base: commonpbutil.NewMsgBase(
+			commonpbutil.WithSourceID(paramtable.GetNodeID()),
+		),
+		CollectionIDs: []int64{collectionID},
+	}
+
+	resp, err := m.queryCoord.ShowCollections(ctx, req)
+	if err != nil {
+		if errors.Is(err, merr.ErrCollectionNotLoaded) {
+			return []int64{}, nil
+		}
+		return nil, err
+	}
+	// backward compatibility: an old coordinator returns no load field info, skip HPL logic
+	if len(resp.GetLoadFields()) < 1 {
+		return []int64{}, nil
+	}
+	return resp.GetLoadFields()[0].GetData(), nil
+}
+
 func (m *MetaCache) describeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) {
 	req := &rootcoordpb.DescribeDatabaseRequest{
 		DbName: dbName,
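
Review note: the validation above enforces three constraints on a user-supplied load field list: the primary key must stay loaded, at least one vector field must stay loaded, and the partition key must stay loaded when the schema defines one. Below is a minimal standalone sketch of the same rule set; the struct and field names are illustrative only, not the proxy's actual types.

	package main

	import "fmt"

	type field struct {
		name                                   string
		isPrimaryKey, isPartitionKey, isVector bool
	}

	// validateLoadFields mirrors the constraint check added in meta_cache.go.
	func validateLoadFields(fields []field, schemaHasPartitionKey bool) error {
		var hasPK, hasPartKey, hasVector bool
		for _, f := range fields {
			hasPK = hasPK || f.isPrimaryKey
			hasPartKey = hasPartKey || f.isPartitionKey
			hasVector = hasVector || f.isVector
		}
		switch {
		case !hasPK:
			return fmt.Errorf("load field list does not contain the primary key field")
		case !hasVector:
			return fmt.Errorf("load field list does not contain any vector field")
		case schemaHasPartitionKey && !hasPartKey:
			return fmt.Errorf("load field list does not contain the partition key field")
		}
		return nil
	}

	func main() {
		ok := []field{{name: "pk", isPrimaryKey: true}, {name: "vec", isVector: true}}
		fmt.Println(validateLoadFields(ok, false))     // <nil>
		fmt.Println(validateLoadFields(ok[:1], false)) // vector field error
	}
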
diff --git a/internal/proxy/task.go b/internal/proxy/task.go
index fc01844170408..ed74ca8bdadcd 100644
--- a/internal/proxy/task.go
+++ b/internal/proxy/task.go
@@ -1611,6 +1611,13 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
 	if err != nil {
 		return err
 	}
+	// prepare load field list
+	// TODO: use the request's load field list once the API proto change is merged
+	loadFields, err := collSchema.GetLoadFieldIDs(nil)
+	if err != nil {
+		return err
+	}
+
 	// check index
 	indexResponse, err := t.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{
 		CollectionID: collID,
@@ -1658,6 +1665,7 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
 		FieldIndexID:   fieldIndexIDs,
 		Refresh:        t.Refresh,
 		ResourceGroups: t.ResourceGroups,
+		LoadFields:     loadFields,
 	}
 	log.Debug("send LoadCollectionRequest to query coordinator",
 		zap.Any("schema", request.Schema))
@@ -1855,6 +1863,12 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
 	if err != nil {
 		return err
 	}
+	// prepare load field list
+	// TODO: use the request's load field list once the API proto change is merged
+	loadFields, err := collSchema.GetLoadFieldIDs(nil)
+	if err != nil {
+		return err
+	}
 	// check index
 	indexResponse, err := t.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{
 		CollectionID: collID,
@@ -1908,6 +1922,7 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
 		FieldIndexID:   fieldIndexIDs,
 		Refresh:        t.Refresh,
 		ResourceGroups: t.ResourceGroups,
+		LoadFields:     loadFields,
 	}
 	t.result, err = t.queryCoord.LoadPartitions(ctx, request)
 	if err = merr.CheckRPCCall(t.result, err); err != nil {
diff --git a/internal/proxy/util.go b/internal/proxy/util.go
index 99438abb4e597..debd8869a9b09 100644
--- a/internal/proxy/util.go
+++ b/internal/proxy/util.go
@@ -986,7 +986,7 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int
 // output_fields=["*",C] ==> [A,B,C,D]
 func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, error) {
 	var primaryFieldName string
-	allFieldNameMap := make(map[string]bool)
+	allFieldNameMap := make(map[string]int64)
 	resultFieldNameMap := make(map[string]bool)
 	resultFieldNames := make([]string, 0)
 	userOutputFieldsMap := make(map[string]bool)
@@ -996,18 +996,21 @@ func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, error) {
 		if field.IsPrimaryKey {
 			primaryFieldName = field.Name
 		}
-		allFieldNameMap[field.Name] = true
+		allFieldNameMap[field.Name] = field.GetFieldID()
 	}
 
 	for _, outputFieldName := range outputFields {
 		outputFieldName = strings.TrimSpace(outputFieldName)
 		if outputFieldName == "*" {
-			for fieldName := range allFieldNameMap {
-				resultFieldNameMap[fieldName] = true
-				userOutputFieldsMap[fieldName] = true
+			for fieldName, fieldID := range allFieldNameMap {
+				// skip fields that are not loaded ("cold" fields)
+				if schema.IsFieldLoaded(fieldID) {
+					resultFieldNameMap[fieldName] = true
+					userOutputFieldsMap[fieldName] = true
+				}
 			}
 		} else {
-			if _, ok := allFieldNameMap[outputFieldName]; ok {
+			if fieldID, ok := allFieldNameMap[outputFieldName]; ok && schema.IsFieldLoaded(fieldID) {
 				resultFieldNameMap[outputFieldName] = true
 				userOutputFieldsMap[outputFieldName] = true
 			} else {
@@ -1026,7 +1029,7 @@ func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, error) {
 					resultFieldNameMap[common.MetaFieldName] = true
 					userOutputFieldsMap[outputFieldName] = true
 				} else {
-					return nil, nil, fmt.Errorf("field %s not exist", outputFieldName)
+					return nil, nil, fmt.Errorf("field %s does not exist or is not loaded", outputFieldName)
 				}
 			}
 		}
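
Review note: with partial load, "*" now expands to the loaded fields only, and naming a cold field in output_fields is rejected with the error above. A small self-contained sketch of those semantics, assuming a plain map in place of schemaInfo:

	package main

	import "fmt"

	// isFieldLoaded mirrors SchemaHelper.IsFieldLoaded: an empty load set
	// means every field is loaded (backward compatibility).
	func isFieldLoaded(loadFields map[int64]bool, fieldID int64) bool {
		return len(loadFields) == 0 || loadFields[fieldID]
	}

	func main() {
		fields := map[string]int64{"pk": 100, "vec": 101, "cold": 102}
		load := map[int64]bool{100: true, 101: true}

		// "*" expansion keeps loaded fields only.
		expanded := []string{}
		for name, id := range fields {
			if isFieldLoaded(load, id) {
				expanded = append(expanded, name)
			}
		}
		fmt.Println(len(expanded)) // 2: pk and vec; "cold" is skipped

		// An explicit request for a cold field fails instead.
		if !isFieldLoaded(load, fields["cold"]) {
			fmt.Println("field cold does not exist or is not loaded")
		}
	}
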
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/merr" @@ -71,6 +72,7 @@ type JobSuite struct { broker *meta.MockBroker nodeMgr *session.NodeManager checkerController *checkers.CheckerController + proxyManager *proxyutil.MockProxyClientManager // Test objects scheduler *Scheduler @@ -140,6 +142,9 @@ func (suite *JobSuite) SetupSuite() { suite.cluster.EXPECT(). ReleasePartitions(mock.Anything, mock.Anything, mock.Anything). Return(merr.Success(), nil).Maybe() + + suite.proxyManager = proxyutil.NewMockProxyClientManager(suite.T()) + suite.proxyManager.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() } func (suite *JobSuite) SetupTest() { @@ -199,6 +204,7 @@ func (suite *JobSuite) SetupTest() { suite.targetMgr, suite.targetObserver, suite.checkerController, + suite.proxyManager, ) } diff --git a/internal/querycoordv2/meta/collection_manager.go b/internal/querycoordv2/meta/collection_manager.go index f7fa1a5685a60..7071556460d3a 100644 --- a/internal/querycoordv2/meta/collection_manager.go +++ b/internal/querycoordv2/meta/collection_manager.go @@ -356,6 +356,17 @@ func (m *CollectionManager) GetFieldIndex(collectionID typeutil.UniqueID) map[in return nil } +func (m *CollectionManager) GetLoadFields(collectionID typeutil.UniqueID) []int64 { + m.rwmutex.RLock() + defer m.rwmutex.RUnlock() + + collection, ok := m.collections[collectionID] + if ok { + return collection.GetLoadFields() + } + return nil +} + func (m *CollectionManager) Exist(collectionID typeutil.UniqueID) bool { m.rwmutex.RLock() defer m.rwmutex.RUnlock() diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go index 6a935a6d521f4..69235f783c445 100644 --- a/internal/querycoordv2/observers/collection_observer.go +++ b/internal/querycoordv2/observers/collection_observer.go @@ -26,14 +26,18 @@ import ( "go.opentelemetry.io/otel/trace" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/checkers" "github.com/milvus-io/milvus/internal/querycoordv2/meta" . 
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/eventlog" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -50,6 +54,8 @@ type CollectionObserver struct { loadTasks *typeutil.ConcurrentMap[string, LoadTask] + proxyManager proxyutil.ProxyClientManagerInterface + stopOnce sync.Once } @@ -65,6 +71,7 @@ func NewCollectionObserver( targetMgr meta.TargetManagerInterface, targetObserver *TargetObserver, checherController *checkers.CheckerController, + proxyManager proxyutil.ProxyClientManagerInterface, ) *CollectionObserver { ob := &CollectionObserver{ dist: dist, @@ -74,6 +81,7 @@ func NewCollectionObserver( checkerController: checherController, partitionLoadedCount: make(map[int64]int), loadTasks: typeutil.NewConcurrentMap[string, LoadTask](), + proxyManager: proxyManager, } // Add load task for collection recovery @@ -347,5 +355,20 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa zap.Int32("partitionLoadPercentage", loadPercentage), zap.Int32("collectionLoadPercentage", collectionPercentage), ) + if collectionPercentage == 100 { + ob.invalidateCache(ctx, partition.GetCollectionID()) + } eventlog.Record(eventlog.NewRawEvt(eventlog.Level_Info, fmt.Sprintf("collection %d load percentage update: %d", partition.CollectionID, loadPercentage))) } + +func (ob *CollectionObserver) invalidateCache(ctx context.Context, collectionID int64) { + ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Second)) + defer cancel() + err := ob.proxyManager.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{ + CollectionID: collectionID, + }, proxyutil.SetMsgType(commonpb.MsgType_LoadCollection)) + if err != nil { + log.Warn("failed to invalidate proxy's shard leader cache", zap.Error(err)) + return + } +} diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go index 6e8d4f541d77b..6f26a924525c4 100644 --- a/internal/querycoordv2/observers/collection_observer_test.go +++ b/internal/querycoordv2/observers/collection_observer_test.go @@ -35,6 +35,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/meta" . 
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -55,12 +56,13 @@ type CollectionObserverSuite struct { nodes []int64 // Mocks - idAllocator func() (int64, error) - etcd *clientv3.Client - kv kv.MetaKv - store metastore.QueryCoordCatalog - broker *meta.MockBroker - cluster *session.MockCluster + idAllocator func() (int64, error) + etcd *clientv3.Client + kv kv.MetaKv + store metastore.QueryCoordCatalog + broker *meta.MockBroker + cluster *session.MockCluster + proxyManager *proxyutil.MockProxyClientManager // Dependencies dist *meta.DistributionManager @@ -162,6 +164,9 @@ func (suite *CollectionObserverSuite) SetupSuite() { 103: 2, } suite.nodes = []int64{1, 2, 3} + + suite.proxyManager = proxyutil.NewMockProxyClientManager(suite.T()) + suite.proxyManager.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() } func (suite *CollectionObserverSuite) SetupTest() { @@ -209,6 +214,7 @@ func (suite *CollectionObserverSuite) SetupTest() { suite.targetMgr, suite.targetObserver, suite.checkerController, + suite.proxyManager, ) for _, collection := range suite.collections { diff --git a/internal/querycoordv2/ops_service_test.go b/internal/querycoordv2/ops_service_test.go index c073bdf0f5fd7..02bb83fbd959c 100644 --- a/internal/querycoordv2/ops_service_test.go +++ b/internal/querycoordv2/ops_service_test.go @@ -41,6 +41,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -66,6 +67,7 @@ type OpsServiceSuite struct { jobScheduler *job.Scheduler taskScheduler *task.MockScheduler balancer balance.Balance + proxyManager *proxyutil.MockProxyClientManager distMgr *meta.DistributionManager distController *dist.MockController @@ -77,6 +79,8 @@ type OpsServiceSuite struct { func (suite *OpsServiceSuite) SetupSuite() { paramtable.Init() + suite.proxyManager = proxyutil.NewMockProxyClientManager(suite.T()) + suite.proxyManager.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() } func (suite *OpsServiceSuite) SetupTest() { @@ -151,6 +155,7 @@ func (suite *OpsServiceSuite) SetupTest() { suite.server.targetMgr, suite.targetObserver, &checkers.CheckerController{}, + suite.proxyManager, ) suite.server.UpdateStateCode(commonpb.StateCode_Healthy) diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index d2c997e339c88..35398beb9c1f0 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -410,6 +410,7 @@ func (s *Server) initObserver() { s.targetMgr, s.targetObserver, s.checkerController, + s.proxyClientManager, ) s.replicaObserver = observers.NewReplicaObserver( diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 78c2fdb89b6f1..ca3a84b83965e 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -587,6 
diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go
index 6b3f4c43d1539..3e5198f3051d1 100644
--- a/internal/querycoordv2/services.go
+++ b/internal/querycoordv2/services.go
@@ -28,6 +28,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/querycoordv2/job"
@@ -86,6 +87,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
 		collection := s.meta.CollectionManager.GetCollection(collectionID)
 		percentage := s.meta.CollectionManager.CalculateLoadPercentage(collectionID)
+		loadFields := s.meta.CollectionManager.GetLoadFields(collectionID)
 		refreshProgress := int64(0)
 		if percentage < 0 {
 			if isGetAll {
@@ -118,6 +120,9 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
 		resp.InMemoryPercentages = append(resp.InMemoryPercentages, int64(percentage))
 		resp.QueryServiceAvailable = append(resp.QueryServiceAvailable, s.checkAnyReplicaAvailable(collectionID))
 		resp.RefreshProgress = append(resp.RefreshProgress, refreshProgress)
+		resp.LoadFields = append(resp.LoadFields, &schemapb.LongArray{
+			Data: loadFields,
+		})
 	}
 
 	return resp, nil
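
Review note: ShowCollectionsResponse keeps its parallel-array layout, so the new load_fields column must stay index-aligned with the other per-collection columns; schema.LongArray is used because proto3 cannot nest repeated int64 directly. A sketch of client-side consumption (roughly what the proxy's getCollectionLoadFields does), using hypothetical plain structs in place of the generated proto types:

	package main

	import "fmt"

	type longArray struct{ Data []int64 }

	// showCollectionsResponse models the parallel arrays of the proto message.
	type showCollectionsResponse struct {
		CollectionIDs []int64
		LoadFields    []longArray
	}

	// loadFieldsFor returns the load-field list for one collection, treating
	// a missing column as "no HPL info" for backward compatibility.
	func loadFieldsFor(resp showCollectionsResponse, collectionID int64) []int64 {
		for i, id := range resp.CollectionIDs {
			if id == collectionID && i < len(resp.LoadFields) {
				return resp.LoadFields[i].Data
			}
		}
		return nil // old coordinator: column absent, caller loads all fields
	}

	func main() {
		resp := showCollectionsResponse{
			CollectionIDs: []int64{441, 442},
			LoadFields:    []longArray{{Data: []int64{100, 101}}, {Data: nil}},
		}
		fmt.Println(loadFieldsFor(resp, 441)) // [100 101]
		fmt.Println(loadFieldsFor(resp, 442)) // []
	}
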
info") return err } + loadFields := ex.meta.GetLoadFields(task.CollectionID()) partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID()) if err != nil { log.Warn("failed to get partitions of collection") @@ -358,6 +359,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { task.CollectionID(), collectionInfo.GetDbName(), task.ResourceGroup(), + loadFields, partitions..., ) @@ -649,6 +651,7 @@ func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.Descr log.Warn("failed to get collection info", zap.Error(err)) return nil, nil, nil, err } + loadFields := ex.meta.GetLoadFields(task.CollectionID()) partitions, err := utils.GetPartitions(ex.meta.CollectionManager, collectionID) if err != nil { log.Warn("failed to get partitions of collection", zap.Error(err)) @@ -660,6 +663,7 @@ func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.Descr task.CollectionID(), collectionInfo.GetDbName(), task.ResourceGroup(), + loadFields, partitions..., ) diff --git a/internal/querycoordv2/task/utils.go b/internal/querycoordv2/task/utils.go index 6bf9a289ce0d8..7536e1ab10f3c 100644 --- a/internal/querycoordv2/task/utils.go +++ b/internal/querycoordv2/task/utils.go @@ -182,13 +182,14 @@ func packReleaseSegmentRequest(task *SegmentTask, action *SegmentAction) *queryp } } -func packLoadMeta(loadType querypb.LoadType, collectionID int64, databaseName string, resourceGroup string, partitions ...int64) *querypb.LoadMetaInfo { +func packLoadMeta(loadType querypb.LoadType, collectionID int64, databaseName string, resourceGroup string, loadFields []int64, partitions ...int64) *querypb.LoadMetaInfo { return &querypb.LoadMetaInfo{ LoadType: loadType, CollectionID: collectionID, PartitionIDs: partitions, DbName: databaseName, ResourceGroup: resourceGroup, + LoadFields: loadFields, } } diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 6baf1bd4bb94b..86c85e0f80ea0 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -146,6 +146,7 @@ type Collection struct { metricType atomic.String // deprecated schema atomic.Pointer[schemapb.CollectionSchema] isGpuIndex bool + loadFields typeutil.Set[int64] refCount *atomic.Uint32 } @@ -227,7 +228,23 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM CCollection NewCollection(const char* schema_proto_blob); */ - schemaBlob, err := proto.Marshal(schema) + + var loadFieldIDs typeutil.Set[int64] + loadSchema := typeutil.Clone(schema) + + // if load fields is specified, do filtering logic + // otherwise use all fields for backward compatibility + if len(loadMetaInfo.GetLoadFields()) > 0 { + loadFieldIDs = typeutil.NewSet(loadMetaInfo.GetLoadFields()...) + loadSchema.Fields = lo.Filter(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool { + // system field shall always be loaded for now + return loadFieldIDs.Contain(field.GetFieldID()) || common.IsSystemField(field.GetFieldID()) + }) + } else { + loadFieldIDs = typeutil.NewSet(lo.Map(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) int64 { return field.GetFieldID() })...) 
diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go
index 6baf1bd4bb94b..86c85e0f80ea0 100644
--- a/internal/querynodev2/segments/collection.go
+++ b/internal/querynodev2/segments/collection.go
@@ -146,6 +146,7 @@ type Collection struct {
 	metricType atomic.String // deprecated
 	schema     atomic.Pointer[schemapb.CollectionSchema]
 	isGpuIndex bool
+	loadFields typeutil.Set[int64]
 
 	refCount *atomic.Uint32
 }
@@ -227,7 +228,23 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
 		CCollection
 		NewCollection(const char* schema_proto_blob);
 	*/
-	schemaBlob, err := proto.Marshal(schema)
+
+	var loadFieldIDs typeutil.Set[int64]
+	loadSchema := typeutil.Clone(schema)
+
+	// if a load field list is specified, filter the schema;
+	// otherwise use all fields for backward compatibility
+	if len(loadMetaInfo.GetLoadFields()) > 0 {
+		loadFieldIDs = typeutil.NewSet(loadMetaInfo.GetLoadFields()...)
+		loadSchema.Fields = lo.Filter(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
+			// system fields shall always be loaded for now
+			return loadFieldIDs.Contain(field.GetFieldID()) || common.IsSystemField(field.GetFieldID())
+		})
+	} else {
+		loadFieldIDs = typeutil.NewSet(lo.Map(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) int64 { return field.GetFieldID() })...)
+	}
+
+	schemaBlob, err := proto.Marshal(loadSchema)
 	if err != nil {
 		log.Warn("marshal schema failed", zap.Error(err))
 		return nil
@@ -263,6 +280,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
 		resourceGroup: loadMetaInfo.GetResourceGroup(),
 		refCount:      atomic.NewUint32(0),
 		isGpuIndex:    isGpuIndex,
+		loadFields:    loadFieldIDs,
 	}
 	for _, partitionID := range loadMetaInfo.GetPartitionIDs() {
 		coll.partitions.Insert(partitionID)
diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go
index ec4127b7d647a..f0b3989cb8a40 100644
--- a/internal/querynodev2/segments/segment_loader.go
+++ b/internal/querynodev2/segments/segment_loader.go
@@ -206,6 +206,14 @@ func (loader *segmentLoader) Load(ctx context.Context,
 		log.Info("no segment to load")
 		return nil, nil
 	}
+	coll := loader.manager.Collection.Get(collectionID)
+	// filter out binlogs of fields that are not in the load field list
+	for _, info := range segments {
+		info.BinlogPaths = lo.Filter(info.GetBinlogPaths(), func(fbl *datapb.FieldBinlog, _ int) bool {
+			return coll.loadFields.Contain(fbl.GetFieldID()) || common.IsSystemField(fbl.GetFieldID())
+		})
+	}
+
 	// Filter out loaded & loading segments
 	infos := loader.prepare(ctx, segmentType, segments...)
 	defer loader.unregister(infos...)
@@ -220,7 +228,7 @@ func (loader *segmentLoader) Load(ctx context.Context,
 	var err error
 	var requestResourceResult requestResourceResult
-	coll := loader.manager.Collection.Get(collectionID)
+
 	if !isLazyLoad(coll, segmentType) {
 		// Check memory & storage limit
 		// no need to check resource for lazy load here
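
Review note: both query-node hunks apply the same predicate: keep a field if and only if it is in the load-field set or is a system field (RowID and Timestamp always stay loaded), so binlogs of cold fields are never even fetched from storage. A compact sketch of that predicate; startOfUserFieldID here is assumed to mirror common.StartOfUserFieldID in milvus:

	package main

	import (
		"fmt"

		"github.com/samber/lo"
	)

	// startOfUserFieldID: RowID (0) and Timestamp (1) sit below it and are
	// always loaded, matching common.IsSystemField.
	const startOfUserFieldID = 100

	func isSystemField(fieldID int64) bool { return fieldID < startOfUserFieldID }

	func main() {
		loadFields := map[int64]bool{100: true, 101: true}
		fieldIDs := []int64{0, 1, 100, 101, 102} // 102 is a cold scalar field

		kept := lo.Filter(fieldIDs, func(id int64, _ int) bool {
			return loadFields[id] || isSystemField(id)
		})
		fmt.Println(kept) // [0 1 100 101]: binlogs of field 102 are skipped
	}
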
diff --git a/pkg/common/common.go b/pkg/common/common.go
index 284baa04e3866..b173e0217b610 100644
--- a/pkg/common/common.go
+++ b/pkg/common/common.go
@@ -23,9 +23,12 @@ import (
 	"strings"
 
 	"github.com/cockroachdb/errors"
+	"github.com/samber/lo"
+	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/pkg/log"
 )
 
 // system field id:
@@ -167,6 +170,7 @@ const (
 	MmapEnabledKey           = "mmap.enabled"
 	LazyLoadEnableKey        = "lazyload.enabled"
 	PartitionKeyIsolationKey = "partitionkey.isolation"
+	FieldSkipLoadKey         = "field.skipLoad"
 )
 
 const (
@@ -327,3 +331,26 @@ func CollectionLevelResourceGroups(kvs []*commonpb.KeyValuePair) ([]string, error
 
 	return nil, fmt.Errorf("collection property not found: %s", CollectionReplicaNumber)
 }
+
+// GetCollectionLoadFields returns the load field ids according to the type params.
+func GetCollectionLoadFields(schema *schemapb.CollectionSchema) []int64 {
+	return lo.FilterMap(schema.GetFields(), func(field *schemapb.FieldSchema, _ int) (int64, bool) {
+		v, err := ShouldFieldBeLoaded(field.GetTypeParams())
+		if err != nil {
+			log.Warn("type param parse skip load failed", zap.Error(err))
+			// if the configuration cannot be parsed, ignore it and load the field
+			return field.GetFieldID(), true
+		}
+		return field.GetFieldID(), v
+	})
+}
+
+// ShouldFieldBeLoaded returns whether a field shall be loaded based on its type params.
+func ShouldFieldBeLoaded(kvs []*commonpb.KeyValuePair) (bool, error) {
+	for _, kv := range kvs {
+		if kv.GetKey() == FieldSkipLoadKey {
+			val, err := strconv.ParseBool(kv.GetValue())
+			return !val, err
+		}
+	}
+	return true, nil
+}
diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go
index 11ca8949f1622..7e77b782f38eb 100644
--- a/pkg/common/common_test.go
+++ b/pkg/common/common_test.go
@@ -149,3 +149,31 @@ func TestCommonPartitionKeyIsolation(t *testing.T) {
 		assert.False(t, res)
 	})
 }
+
+func TestShouldFieldBeLoaded(t *testing.T) {
+	type testCase struct {
+		tag          string
+		input        []*commonpb.KeyValuePair
+		expectOutput bool
+		expectError  bool
+	}
+
+	testcases := []testCase{
+		{tag: "no_params", expectOutput: true},
+		{tag: "skipload_true", input: []*commonpb.KeyValuePair{{Key: FieldSkipLoadKey, Value: "true"}}, expectOutput: false},
+		{tag: "skipload_false", input: []*commonpb.KeyValuePair{{Key: FieldSkipLoadKey, Value: "false"}}, expectOutput: true},
+		{tag: "bad_skip_load_value", input: []*commonpb.KeyValuePair{{Key: FieldSkipLoadKey, Value: "abc"}}, expectError: true},
+	}
+
+	for _, tc := range testcases {
+		t.Run(tc.tag, func(t *testing.T) {
+			result, err := ShouldFieldBeLoaded(tc.input)
+			if tc.expectError {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+				assert.Equal(t, tc.expectOutput, result)
+			}
+		})
+	}
+}
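
Review note: a field is marked cold at create-collection time through the field.skipLoad type param; ShouldFieldBeLoaded simply inverts it and defaults to loaded when the param is absent. A standalone re-statement of that helper, with an illustrative stand-in for commonpb.KeyValuePair:

	package main

	import (
		"fmt"
		"strconv"
	)

	// kv mirrors commonpb.KeyValuePair for this sketch.
	type kv struct{ Key, Value string }

	const fieldSkipLoadKey = "field.skipLoad" // same literal as FieldSkipLoadKey above

	// shouldFieldBeLoaded re-states the helper added in pkg/common: a field
	// is loaded unless its type params carry field.skipLoad=true.
	func shouldFieldBeLoaded(params []kv) (bool, error) {
		for _, p := range params {
			if p.Key == fieldSkipLoadKey {
				skip, err := strconv.ParseBool(p.Value)
				return !skip, err
			}
		}
		return true, nil
	}

	func main() {
		cold := []kv{{Key: fieldSkipLoadKey, Value: "true"}} // declared at create-collection time
		hot := []kv{}                                        // no param: loaded by default

		for _, params := range [][]kv{cold, hot} {
			loaded, err := shouldFieldBeLoaded(params)
			fmt.Println(loaded, err) // false <nil>, then true <nil>
		}
	}
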
diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go
index d3d8f912797b6..cc0ce10a0239f 100644
--- a/pkg/util/typeutil/schema.go
+++ b/pkg/util/typeutil/schema.go
@@ -257,14 +257,14 @@ type SchemaHelper struct {
 	primaryKeyOffset   int
 	partitionKeyOffset int
 	dynamicFieldOffset int
+	loadFields         Set[int64]
 }
 
-// CreateSchemaHelper returns a new SchemaHelper object
-func CreateSchemaHelper(schema *schemapb.CollectionSchema) (*SchemaHelper, error) {
+func CreateSchemaHelperWithLoadFields(schema *schemapb.CollectionSchema, loadFields []int64) (*SchemaHelper, error) {
 	if schema == nil {
 		return nil, errors.New("schema is nil")
 	}
-	schemaHelper := SchemaHelper{schema: schema, nameOffset: make(map[string]int), idOffset: make(map[int64]int), primaryKeyOffset: -1, partitionKeyOffset: -1, dynamicFieldOffset: -1}
+	schemaHelper := SchemaHelper{schema: schema, nameOffset: make(map[string]int), idOffset: make(map[int64]int), primaryKeyOffset: -1, partitionKeyOffset: -1, dynamicFieldOffset: -1, loadFields: NewSet(loadFields...)}
 	for offset, field := range schema.Fields {
 		if _, ok := schemaHelper.nameOffset[field.Name]; ok {
 			return nil, fmt.Errorf("duplicated fieldName: %s", field.Name)
@@ -298,6 +298,11 @@ func CreateSchemaHelper(schema *schemapb.CollectionSchema) (*SchemaHelper, error
 	return &schemaHelper, nil
 }
 
+// CreateSchemaHelper returns a new SchemaHelper object
+func CreateSchemaHelper(schema *schemapb.CollectionSchema) (*SchemaHelper, error) {
+	return CreateSchemaHelperWithLoadFields(schema, nil)
+}
+
 // GetPrimaryKeyField returns the schema of the primary key
 func (helper *SchemaHelper) GetPrimaryKeyField() (*schemapb.FieldSchema, error) {
 	if helper.primaryKeyOffset == -1 {
@@ -338,7 +343,20 @@ func (helper *SchemaHelper) GetFieldFromNameDefaultJSON(fieldName string) (*schemapb.FieldSchema, error) {
 	if !ok {
 		return helper.getDefaultJSONField(fieldName)
 	}
-	return helper.schema.Fields[offset], nil
+	fieldSchema := helper.schema.Fields[offset]
+	if !helper.IsFieldLoaded(fieldSchema.GetFieldID()) {
+		return nil, errors.Newf("field %s is not loaded", fieldSchema.GetName())
+	}
+	return fieldSchema, nil
+}
+
+// IsFieldLoaded returns whether the field with the given ID is loaded.
+// If the load field list is not provided, all fields are treated as loaded.
+func (helper *SchemaHelper) IsFieldLoaded(fieldID int64) bool {
+	if len(helper.loadFields) == 0 {
+		return true
+	}
+	return helper.loadFields.Contain(fieldID)
 }
 
 func (helper *SchemaHelper) getDefaultJSONField(fieldName string) (*schemapb.FieldSchema, error) {