From 302650ae0e2fa2979a57a1b71bc9a49f9c91766b Mon Sep 17 00:00:00 2001
From: SimFG
Date: Wed, 27 Nov 2024 11:00:36 +0800
Subject: [PATCH] fix: use the default partition for the limit quota when the request partition name is empty (#38005)

- issue: #37685

Signed-off-by: SimFG
---
 internal/proxy/meta_cache.go                  |  11 +
 internal/proxy/rate_limit_interceptor.go      |   8 +-
 internal/proxy/rate_limit_interceptor_test.go |  41 +++-
 internal/proxy/task_insert.go                 |   7 +-
 internal/proxy/task_test.go                   | 198 ++++++++++++++++++
 internal/proxy/task_upsert.go                 |   8 +-
 6 files changed, 267 insertions(+), 6 deletions(-)

diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go
index daea35b137909..57543a2abd64a 100644
--- a/internal/proxy/meta_cache.go
+++ b/internal/proxy/meta_cache.go
@@ -264,6 +264,7 @@ type partitionInfo struct {
 	partitionID         typeutil.UniqueID
 	createdTimestamp    uint64
 	createdUtcTimestamp uint64
+	isDefault           bool
 }
 
 func (info *collectionInfo) isCollectionCached() bool {
@@ -427,12 +428,14 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
 		return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
 	}
 
+	defaultPartitionName := Params.CommonCfg.DefaultPartitionName.GetValue()
 	infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
 		return &partitionInfo{
 			name:                partitions.PartitionNames[idx],
 			partitionID:         partitions.PartitionIDs[idx],
 			createdTimestamp:    partitions.CreatedTimestamps[idx],
 			createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
+			isDefault:           partitions.PartitionNames[idx] == defaultPartitionName,
 		}
 	})
 
@@ -630,6 +633,14 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa
 		return nil, err
 	}
 
+	if partitionName == "" {
+		for _, info := range partitions.partitionInfos {
+			if info.isDefault {
+				return info, nil
+			}
+		}
+	}
+
 	info, ok := partitions.name2Info[partitionName]
 	if !ok {
 		return nil, merr.WrapErrPartitionNotFound(partitionName)
diff --git a/internal/proxy/rate_limit_interceptor.go b/internal/proxy/rate_limit_interceptor.go
index ca4b99c23e30b..996961fd8000d 100644
--- a/internal/proxy/rate_limit_interceptor.go
+++ b/internal/proxy/rate_limit_interceptor.go
@@ -84,7 +84,13 @@ func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map
 		return 0, nil, err
 	}
 	if r.GetPartitionName() == "" {
-		return db.dbID, map[int64][]int64{collectionID: {}}, nil
+		collectionSchema, err := globalMetaCache.GetCollectionSchema(ctx, r.GetDbName(), r.GetCollectionName())
+		if err != nil {
+			return 0, nil, err
+		}
+		if collectionSchema.IsPartitionKeyCollection() {
+			return db.dbID, map[int64][]int64{collectionID: {}}, nil
+		}
 	}
 	part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
 	if err != nil {
diff --git a/internal/proxy/rate_limit_interceptor_test.go b/internal/proxy/rate_limit_interceptor_test.go
index 56e9345c85abd..22a5c98c326c9 100644
--- a/internal/proxy/rate_limit_interceptor_test.go
+++ b/internal/proxy/rate_limit_interceptor_test.go
@@ -299,6 +299,7 @@ func TestRateLimitInterceptor(t *testing.T) {
 		dbID:             100,
 		createdTimestamp: 1,
 	}, nil)
+	mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
 	globalMetaCache = mockCache
 
 	limiter := limiterMock{rate: 100}
@@ -437,6 +438,41 @@ func TestGetInfo(t *testing.T) {
 		}
 	})
 
+	t.Run("fail to get collection schema", func(t *testing.T) {
+		mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
+			dbID:             100,
+			createdTimestamp: 1,
+		}, nil).Once()
+		mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
+		mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once()
+
+		_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
+			DbName:         "foo",
+			CollectionName: "coo",
+		})
+		assert.Error(t, err)
+	})
+
+	t.Run("partition key mode", func(t *testing.T) {
+		mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
+			dbID:             100,
+			createdTimestamp: 1,
+		}, nil).Once()
+		mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
+		mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
+			hasPartitionKeyField: true,
+		}, nil).Once()
+
+		db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
+			DbName:         "foo",
+			CollectionName: "coo",
+		})
+		assert.NoError(t, err)
+		assert.Equal(t, int64(100), db)
+		assert.NotNil(t, col2par[1])
+		assert.Equal(t, 0, len(col2par[1]))
+	})
+
 	t.Run("fail to get partition", func(t *testing.T) {
 		mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
 			dbID:             100,
@@ -467,11 +503,12 @@ func TestGetInfo(t *testing.T) {
 			dbID:             100,
 			createdTimestamp: 1,
 		}, nil).Times(3)
+		mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil).Times(1)
 		mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Times(3)
 		mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
 			name:        "p1",
 			partitionID: 100,
-		}, nil).Twice()
+		}, nil).Times(3)
 		{
 			db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
 				DbName:         "foo",
@@ -491,7 +528,7 @@ func TestGetInfo(t *testing.T) {
 			assert.NoError(t, err)
 			assert.Equal(t, int64(100), db)
 			assert.NotNil(t, col2par[10])
-			assert.Equal(t, 0, len(col2par[10]))
+			assert.Equal(t, int64(100), col2par[10][0])
 		}
 		{
 			db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go
index fd86fc9d3c343..620b469d2fa96 100644
--- a/internal/proxy/task_insert.go
+++ b/internal/proxy/task_insert.go
@@ -202,7 +202,12 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
 	// insert to _default partition
 	partitionTag := it.insertMsg.GetPartitionName()
 	if len(partitionTag) <= 0 {
-		partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
+		pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.insertMsg.GetDbName(), collectionName, "")
+		if err != nil {
+			log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
+			return err
+		}
+		partitionTag = pinfo.name
 		it.insertMsg.PartitionName = partitionTag
 	}
 
diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go
index 66c823a9f2e2c..1f74590c16bcc 100644
--- a/internal/proxy/task_test.go
+++ b/internal/proxy/task_test.go
@@ -3651,6 +3651,204 @@ func TestPartitionKey(t *testing.T) {
 	})
 }
 
+func TestDefaultPartition(t *testing.T) {
+	rc := NewRootCoordMock()
+
+	defer rc.Close()
+	qc := getQueryCoordClient()
+	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
+
+	ctx := context.Background()
+
+	mgr := newShardClientMgr()
+	err := InitMetaCache(ctx, rc, qc, mgr)
+	assert.NoError(t, err)
+
+	shardsNum := common.DefaultShardsNum
+	prefix := "TestInsertTaskWithPartitionKey"
+	collectionName := prefix + funcutil.GenRandomStr()
+
+	fieldName2Type := make(map[string]schemapb.DataType)
+	fieldName2Type["int64_field"] = schemapb.DataType_Int64
+	fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
+	fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
+	schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
+	marshaledSchema, err := proto.Marshal(schema)
+	assert.NoError(t, err)
+
+	t.Run("create collection", func(t *testing.T) {
+		createCollectionTask := &createCollectionTask{
+			Condition: NewTaskCondition(ctx),
+			CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
+				Base: &commonpb.MsgBase{
+					MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
+					Timestamp: Timestamp(time.Now().UnixNano()),
+				},
+				DbName:         "",
+				CollectionName: collectionName,
+				Schema:         marshaledSchema,
+				ShardsNum:      shardsNum,
+			},
+			ctx:       ctx,
+			rootCoord: rc,
+			result:    nil,
+			schema:    nil,
+		}
+		err = createCollectionTask.PreExecute(ctx)
+		assert.NoError(t, err)
+		err = createCollectionTask.Execute(ctx)
+		assert.NoError(t, err)
+	})
+
+	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
+	assert.NoError(t, err)
+
+	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
+	factory := newSimpleMockMsgStreamFactory()
+	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
+	defer chMgr.removeAllDMLStream()
+
+	_, err = chMgr.getOrCreateDmlStream(collectionID)
+	assert.NoError(t, err)
+	pchans, err := chMgr.getChannels(collectionID)
+	assert.NoError(t, err)
+
+	interval := time.Millisecond * 10
+	tso := newMockTsoAllocator()
+
+	ticker := newChannelsTimeTicker(ctx, interval, []string{}, newGetStatisticsFunc(pchans), tso)
+	_ = ticker.start()
+	defer ticker.close()
+
+	idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID())
+	assert.NoError(t, err)
+	_ = idAllocator.Start()
+	defer idAllocator.Close()
+
+	segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
+	assert.NoError(t, err)
+	segAllocator.Init()
+	_ = segAllocator.Start()
+	defer segAllocator.Close()
+
+	nb := 10
+	fieldID := common.StartOfUserFieldID
+	fieldDatas := make([]*schemapb.FieldData, 0)
+	for fieldName, dataType := range fieldName2Type {
+		fieldData := generateFieldData(dataType, fieldName, nb)
+		fieldData.FieldId = int64(fieldID)
+		fieldDatas = append(fieldDatas, generateFieldData(dataType, fieldName, nb))
+		fieldID++
+	}
+
+	t.Run("Insert", func(t *testing.T) {
+		it := &insertTask{
+			insertMsg: &BaseInsertTask{
+				BaseMsg: msgstream.BaseMsg{},
+				InsertRequest: &msgpb.InsertRequest{
+					Base: &commonpb.MsgBase{
+						MsgType:  commonpb.MsgType_Insert,
+						MsgID:    0,
+						SourceID: paramtable.GetNodeID(),
+					},
+					CollectionName: collectionName,
+					FieldsData:     fieldDatas,
+					NumRows:        uint64(nb),
+					Version:        msgpb.InsertDataVersion_ColumnBased,
+				},
+			},
+
+			Condition: NewTaskCondition(ctx),
+			ctx:       ctx,
+			result: &milvuspb.MutationResult{
+				Status:       merr.Success(),
+				IDs:          nil,
+				SuccIndex:    nil,
+				ErrIndex:     nil,
+				Acknowledged: false,
+				InsertCnt:    0,
+				DeleteCnt:    0,
+				UpsertCnt:    0,
+				Timestamp:    0,
+			},
+			idAllocator:   idAllocator,
+			segIDAssigner: segAllocator,
+			chMgr:         chMgr,
+			chTicker:      ticker,
+			vChannels:     nil,
+			pChannels:     nil,
+			schema:        nil,
+		}
+
+		it.insertMsg.PartitionName = ""
+		assert.NoError(t, it.OnEnqueue())
+		assert.NoError(t, it.PreExecute(ctx))
+		assert.NoError(t, it.Execute(ctx))
+		assert.NoError(t, it.PostExecute(ctx))
+	})
+
+	t.Run("Upsert", func(t *testing.T) {
+		hash := testutils.GenerateHashKeys(nb)
+		ut := &upsertTask{
+			ctx:       ctx,
+			Condition: NewTaskCondition(ctx),
+			baseMsg: msgstream.BaseMsg{
+				HashValues: hash,
+			},
+			req: &milvuspb.UpsertRequest{
+				Base: commonpbutil.NewMsgBase(
+					commonpbutil.WithMsgType(commonpb.MsgType_Upsert),
+					commonpbutil.WithSourceID(paramtable.GetNodeID()),
+				),
+				CollectionName: collectionName,
+				FieldsData:     fieldDatas,
+				NumRows:        uint32(nb),
+			},
+
+			result: &milvuspb.MutationResult{
+				Status: merr.Success(),
+				IDs: &schemapb.IDs{
+					IdField: nil,
+				},
+			},
+			idAllocator:   idAllocator,
+			segIDAssigner: segAllocator,
+			chMgr:         chMgr,
+			chTicker:      ticker,
+		}
+
+		ut.req.PartitionName = ""
+		assert.NoError(t, ut.OnEnqueue())
+		assert.NoError(t, ut.PreExecute(ctx))
+		assert.NoError(t, ut.Execute(ctx))
+		assert.NoError(t, ut.PostExecute(ctx))
+	})
+
+	t.Run("delete", func(t *testing.T) {
+		dt := &deleteTask{
+			Condition: NewTaskCondition(ctx),
+			req: &milvuspb.DeleteRequest{
+				CollectionName: collectionName,
+				Expr:           "int64_field in [0, 1]",
+			},
+			ctx: ctx,
+			primaryKeys: &schemapb.IDs{
+				IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{0, 1}}},
+			},
+			idAllocator:  idAllocator,
+			chMgr:        chMgr,
+			chTicker:     ticker,
+			collectionID: collectionID,
+			vChannels:    []string{"test-channel"},
+		}
+
+		dt.req.PartitionName = ""
+		assert.NoError(t, dt.PreExecute(ctx))
+		assert.NoError(t, dt.Execute(ctx))
+		assert.NoError(t, dt.PostExecute(ctx))
+	})
+}
+
 func TestClusteringKey(t *testing.T) {
 	rc := NewRootCoordMock()
 
diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go
index 154bbba8753b7..3ca4853fa9f53 100644
--- a/internal/proxy/task_upsert.go
+++ b/internal/proxy/task_upsert.go
@@ -317,8 +317,12 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
 		// insert to _default partition
 		partitionTag := it.req.GetPartitionName()
 		if len(partitionTag) <= 0 {
-			partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
-			it.req.PartitionName = partitionTag
+			pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.req.GetDbName(), collectionName, "")
+			if err != nil {
+				log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
+				return err
+			}
+			it.req.PartitionName = pinfo.name
 		}
 	}
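
Reviewer note (not part of the patch): the standalone Go sketch below illustrates the lookup behavior this change introduces. After the patch, a request that leaves the partition name empty is resolved through the proxy's metadata cache to the collection's actual default partition, so the rate-limit quota and the insert/upsert paths charge the default partition instead of treating the request as partition-less, while partition-key collections keep charging quota at the collection level. The PartitionInfo type and the resolvePartition helper are hypothetical names used only for this illustration; they are not part of the Milvus codebase.

package main

import (
	"errors"
	"fmt"
)

// PartitionInfo mirrors the idea behind the patched partitionInfo struct:
// the cache remembers which cached partition is the default one.
type PartitionInfo struct {
	Name      string
	ID        int64
	IsDefault bool
}

// resolvePartition shows the lookup order used after this patch: an empty
// requested name falls back to the cached default partition, and a non-empty
// name is matched against the cached partitions as before.
func resolvePartition(cached []PartitionInfo, requested string) (PartitionInfo, error) {
	if requested == "" {
		for _, p := range cached {
			if p.IsDefault {
				return p, nil
			}
		}
	}
	for _, p := range cached {
		if p.Name == requested {
			return p, nil
		}
	}
	return PartitionInfo{}, errors.New("partition not found: " + requested)
}

func main() {
	cached := []PartitionInfo{
		{Name: "_default", ID: 100, IsDefault: true},
		{Name: "p1", ID: 101},
	}

	// A request with an empty partition name now resolves to the default
	// partition, so its quota is accounted against that partition's ID.
	p, err := resolvePartition(cached, "")
	fmt.Println(p.Name, p.ID, err) // _default 100 <nil>
}

Caching an is-default flag next to each partition, as the meta_cache.go hunk does, keeps this fallback a pure in-memory lookup; the sketch models that with the IsDefault field.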