fix: use the default partition for the limit quota when the request partition name is empty (#38005)

- issue: #37685

Signed-off-by: SimFG <[email protected]>
SimFG authored Nov 27, 2024
1 parent 49ee46e commit 302650a
Showing 6 changed files with 267 additions and 6 deletions.
11 changes: 11 additions & 0 deletions internal/proxy/meta_cache.go
@@ -264,6 +264,7 @@ type partitionInfo struct {
     partitionID typeutil.UniqueID
     createdTimestamp uint64
     createdUtcTimestamp uint64
+    isDefault bool
 }

 func (info *collectionInfo) isCollectionCached() bool {
@@ -427,12 +428,14 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string,
         return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String())
     }

+    defaultPartitionName := Params.CommonCfg.DefaultPartitionName.GetValue()
     infos := lo.Map(partitions.GetPartitionIDs(), func(partitionID int64, idx int) *partitionInfo {
         return &partitionInfo{
             name: partitions.PartitionNames[idx],
             partitionID: partitions.PartitionIDs[idx],
             createdTimestamp: partitions.CreatedTimestamps[idx],
             createdUtcTimestamp: partitions.CreatedUtcTimestamps[idx],
+            isDefault: partitions.PartitionNames[idx] == defaultPartitionName,
         }
     })

@@ -630,6 +633,14 @@ func (m *MetaCache) GetPartitionInfo(ctx context.Context, database, collectionNa
         return nil, err
     }

+    if partitionName == "" {
+        for _, info := range partitions.partitionInfos {
+            if info.isDefault {
+                return info, nil
+            }
+        }
+    }
+
     info, ok := partitions.name2Info[partitionName]
     if !ok {
         return nil, merr.WrapErrPartitionNotFound(partitionName)
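
With the isDefault flag in place, an empty partition name now resolves to the default partition before the name2Info lookup. Below is a minimal, self-contained sketch of that lookup order; the types are simplified stand-ins (only the field names mirror the real partitionInfo in meta_cache.go), and the plain error value is illustrative rather than the merr wrapper used above.

package main

import (
    "errors"
    "fmt"
)

// Simplified stand-ins for the proxy's cached partition metadata.
type partitionInfo struct {
    name      string
    isDefault bool
}

type partitionInfos struct {
    partitionInfos []*partitionInfo
    name2Info      map[string]*partitionInfo
}

// getPartitionInfo mirrors the patched lookup order: an empty partition name
// resolves to the partition flagged as default instead of failing the
// name2Info lookup with "partition not found".
func getPartitionInfo(partitions *partitionInfos, partitionName string) (*partitionInfo, error) {
    if partitionName == "" {
        for _, info := range partitions.partitionInfos {
            if info.isDefault {
                return info, nil
            }
        }
    }
    info, ok := partitions.name2Info[partitionName]
    if !ok {
        return nil, errors.New("partition not found: " + partitionName)
    }
    return info, nil
}

func main() {
    def := &partitionInfo{name: "_default", isDefault: true}
    p1 := &partitionInfo{name: "p1"}
    parts := &partitionInfos{
        partitionInfos: []*partitionInfo{def, p1},
        name2Info:      map[string]*partitionInfo{"_default": def, "p1": p1},
    }

    info, err := getPartitionInfo(parts, "") // empty name falls back to the default partition
    fmt.Println(info.name, err)              // _default <nil>
}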
8 changes: 7 additions & 1 deletion internal/proxy/rate_limit_interceptor.go
@@ -84,7 +84,13 @@ func getCollectionAndPartitionID(ctx context.Context, r reqPartName) (int64, map
         return 0, nil, err
     }
     if r.GetPartitionName() == "" {
-        return db.dbID, map[int64][]int64{collectionID: {}}, nil
+        collectionSchema, err := globalMetaCache.GetCollectionSchema(ctx, r.GetDbName(), r.GetCollectionName())
+        if err != nil {
+            return 0, nil, err
+        }
+        if collectionSchema.IsPartitionKeyCollection() {
+            return db.dbID, map[int64][]int64{collectionID: {}}, nil
+        }
     }
     part, err := globalMetaCache.GetPartitionInfo(ctx, r.GetDbName(), r.GetCollectionName(), r.GetPartitionName())
     if err != nil {
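
For quota accounting, the effect is: an empty partition name on a partition-key collection keeps the collection-level bucket (the old behavior), while on a regular collection it now falls through to GetPartitionInfo and charges the default partition explicitly. A hypothetical condensation of that decision follows; the function name and signature are illustrative only, not the proxy's actual API.

package main

import "fmt"

// resolveQuotaTarget condenses the patched getCollectionAndPartitionID flow:
// the caller is assumed to have already resolved the collection ID, the
// schema's partition-key property, and (when needed) the default partition ID.
func resolveQuotaTarget(partitionName string, isPartitionKeyCollection bool,
    collectionID, resolvedPartitionID int64) map[int64][]int64 {
    if partitionName == "" && isPartitionKeyCollection {
        // Partition-key collections spread rows across internal partitions,
        // so the quota stays at collection granularity.
        return map[int64][]int64{collectionID: {}}
    }
    // Otherwise the resolved partition (the default one when the request
    // name is empty) is charged explicitly.
    return map[int64][]int64{collectionID: {resolvedPartitionID}}
}

func main() {
    fmt.Println(resolveQuotaTarget("", true, 10, 100))  // map[10:[]]
    fmt.Println(resolveQuotaTarget("", false, 10, 100)) // map[10:[100]]
    fmt.Println(resolveQuotaTarget("p1", false, 10, 7)) // map[10:[7]]
}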
41 changes: 39 additions & 2 deletions internal/proxy/rate_limit_interceptor_test.go
@@ -299,6 +299,7 @@ func TestRateLimitInterceptor(t *testing.T) {
             dbID: 100,
             createdTimestamp: 1,
         }, nil)
+        mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
         globalMetaCache = mockCache

         limiter := limiterMock{rate: 100}
@@ -437,6 +438,41 @@ func TestGetInfo(t *testing.T) {
         }
     })

+    t.Run("fail to get collection schema", func(t *testing.T) {
+        mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
+            dbID: 100,
+            createdTimestamp: 1,
+        }, nil).Once()
+        mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
+        mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once()
+
+        _, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
+            DbName: "foo",
+            CollectionName: "coo",
+        })
+        assert.Error(t, err)
+    })
+
+    t.Run("partition key mode", func(t *testing.T) {
+        mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
+            dbID: 100,
+            createdTimestamp: 1,
+        }, nil).Once()
+        mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
+        mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{
+            hasPartitionKeyField: true,
+        }, nil).Once()
+
+        db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
+            DbName: "foo",
+            CollectionName: "coo",
+        })
+        assert.NoError(t, err)
+        assert.Equal(t, int64(100), db)
+        assert.NotNil(t, col2par[1])
+        assert.Equal(t, 0, len(col2par[1]))
+    })
+
     t.Run("fail to get partition", func(t *testing.T) {
         mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{
             dbID: 100,
@@ -467,11 +503,12 @@ func TestGetInfo(t *testing.T) {
             dbID: 100,
             createdTimestamp: 1,
         }, nil).Times(3)
+        mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil).Times(1)
         mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(int64(10), nil).Times(3)
         mockCache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
             name: "p1",
             partitionID: 100,
-        }, nil).Twice()
+        }, nil).Times(3)
         {
             db, col2par, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
                 DbName: "foo",
@@ -491,7 +528,7 @@
             assert.NoError(t, err)
             assert.Equal(t, int64(100), db)
             assert.NotNil(t, col2par[10])
-            assert.Equal(t, 0, len(col2par[10]))
+            assert.Equal(t, int64(100), col2par[10][0])
         }
         {
             db, col2par, err := getCollectionAndPartitionIDs(ctx, &milvuspb.SearchRequest{
7 changes: 6 additions & 1 deletion internal/proxy/task_insert.go
@@ -202,7 +202,12 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
     // insert to _default partition
     partitionTag := it.insertMsg.GetPartitionName()
     if len(partitionTag) <= 0 {
-        partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
+        pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.insertMsg.GetDbName(), collectionName, "")
+        if err != nil {
+            log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
+            return err
+        }
+        partitionTag = pinfo.name
         it.insertMsg.PartitionName = partitionTag
     }

198 changes: 198 additions & 0 deletions internal/proxy/task_test.go
@@ -3651,6 +3651,204 @@ func TestPartitionKey(t *testing.T) {
     })
 }

+func TestDefaultPartition(t *testing.T) {
+    rc := NewRootCoordMock()
+
+    defer rc.Close()
+    qc := getQueryCoordClient()
+    qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
+
+    ctx := context.Background()
+
+    mgr := newShardClientMgr()
+    err := InitMetaCache(ctx, rc, qc, mgr)
+    assert.NoError(t, err)
+
+    shardsNum := common.DefaultShardsNum
+    prefix := "TestInsertTaskWithPartitionKey"
+    collectionName := prefix + funcutil.GenRandomStr()
+
+    fieldName2Type := make(map[string]schemapb.DataType)
+    fieldName2Type["int64_field"] = schemapb.DataType_Int64
+    fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
+    fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
+    schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
+    marshaledSchema, err := proto.Marshal(schema)
+    assert.NoError(t, err)
+
+    t.Run("create collection", func(t *testing.T) {
+        createCollectionTask := &createCollectionTask{
+            Condition: NewTaskCondition(ctx),
+            CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
+                Base: &commonpb.MsgBase{
+                    MsgID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
+                    Timestamp: Timestamp(time.Now().UnixNano()),
+                },
+                DbName: "",
+                CollectionName: collectionName,
+                Schema: marshaledSchema,
+                ShardsNum: shardsNum,
+            },
+            ctx: ctx,
+            rootCoord: rc,
+            result: nil,
+            schema: nil,
+        }
+        err = createCollectionTask.PreExecute(ctx)
+        assert.NoError(t, err)
+        err = createCollectionTask.Execute(ctx)
+        assert.NoError(t, err)
+    })
+
+    collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
+    assert.NoError(t, err)
+
+    dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
+    factory := newSimpleMockMsgStreamFactory()
+    chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
+    defer chMgr.removeAllDMLStream()
+
+    _, err = chMgr.getOrCreateDmlStream(collectionID)
+    assert.NoError(t, err)
+    pchans, err := chMgr.getChannels(collectionID)
+    assert.NoError(t, err)
+
+    interval := time.Millisecond * 10
+    tso := newMockTsoAllocator()
+
+    ticker := newChannelsTimeTicker(ctx, interval, []string{}, newGetStatisticsFunc(pchans), tso)
+    _ = ticker.start()
+    defer ticker.close()
+
+    idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID())
+    assert.NoError(t, err)
+    _ = idAllocator.Start()
+    defer idAllocator.Close()
+
+    segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
+    assert.NoError(t, err)
+    segAllocator.Init()
+    _ = segAllocator.Start()
+    defer segAllocator.Close()
+
+    nb := 10
+    fieldID := common.StartOfUserFieldID
+    fieldDatas := make([]*schemapb.FieldData, 0)
+    for fieldName, dataType := range fieldName2Type {
+        fieldData := generateFieldData(dataType, fieldName, nb)
+        fieldData.FieldId = int64(fieldID)
+        fieldDatas = append(fieldDatas, generateFieldData(dataType, fieldName, nb))
+        fieldID++
+    }
+
+    t.Run("Insert", func(t *testing.T) {
+        it := &insertTask{
+            insertMsg: &BaseInsertTask{
+                BaseMsg: msgstream.BaseMsg{},
+                InsertRequest: &msgpb.InsertRequest{
+                    Base: &commonpb.MsgBase{
+                        MsgType: commonpb.MsgType_Insert,
+                        MsgID: 0,
+                        SourceID: paramtable.GetNodeID(),
+                    },
+                    CollectionName: collectionName,
+                    FieldsData: fieldDatas,
+                    NumRows: uint64(nb),
+                    Version: msgpb.InsertDataVersion_ColumnBased,
+                },
+            },
+
+            Condition: NewTaskCondition(ctx),
+            ctx: ctx,
+            result: &milvuspb.MutationResult{
+                Status: merr.Success(),
+                IDs: nil,
+                SuccIndex: nil,
+                ErrIndex: nil,
+                Acknowledged: false,
+                InsertCnt: 0,
+                DeleteCnt: 0,
+                UpsertCnt: 0,
+                Timestamp: 0,
+            },
+            idAllocator: idAllocator,
+            segIDAssigner: segAllocator,
+            chMgr: chMgr,
+            chTicker: ticker,
+            vChannels: nil,
+            pChannels: nil,
+            schema: nil,
+        }
+
+        it.insertMsg.PartitionName = ""
+        assert.NoError(t, it.OnEnqueue())
+        assert.NoError(t, it.PreExecute(ctx))
+        assert.NoError(t, it.Execute(ctx))
+        assert.NoError(t, it.PostExecute(ctx))
+    })
+
+    t.Run("Upsert", func(t *testing.T) {
+        hash := testutils.GenerateHashKeys(nb)
+        ut := &upsertTask{
+            ctx: ctx,
+            Condition: NewTaskCondition(ctx),
+            baseMsg: msgstream.BaseMsg{
+                HashValues: hash,
+            },
+            req: &milvuspb.UpsertRequest{
+                Base: commonpbutil.NewMsgBase(
+                    commonpbutil.WithMsgType(commonpb.MsgType_Upsert),
+                    commonpbutil.WithSourceID(paramtable.GetNodeID()),
+                ),
+                CollectionName: collectionName,
+                FieldsData: fieldDatas,
+                NumRows: uint32(nb),
+            },
+
+            result: &milvuspb.MutationResult{
+                Status: merr.Success(),
+                IDs: &schemapb.IDs{
+                    IdField: nil,
+                },
+            },
+            idAllocator: idAllocator,
+            segIDAssigner: segAllocator,
+            chMgr: chMgr,
+            chTicker: ticker,
+        }
+
+        ut.req.PartitionName = ""
+        assert.NoError(t, ut.OnEnqueue())
+        assert.NoError(t, ut.PreExecute(ctx))
+        assert.NoError(t, ut.Execute(ctx))
+        assert.NoError(t, ut.PostExecute(ctx))
+    })
+
+    t.Run("delete", func(t *testing.T) {
+        dt := &deleteTask{
+            Condition: NewTaskCondition(ctx),
+            req: &milvuspb.DeleteRequest{
+                CollectionName: collectionName,
+                Expr: "int64_field in [0, 1]",
+            },
+            ctx: ctx,
+            primaryKeys: &schemapb.IDs{
+                IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{0, 1}}},
+            },
+            idAllocator: idAllocator,
+            chMgr: chMgr,
+            chTicker: ticker,
+            collectionID: collectionID,
+            vChannels: []string{"test-channel"},
+        }
+
+        dt.req.PartitionName = ""
+        assert.NoError(t, dt.PreExecute(ctx))
+        assert.NoError(t, dt.Execute(ctx))
+        assert.NoError(t, dt.PostExecute(ctx))
+    })
+}
+
 func TestClusteringKey(t *testing.T) {
     rc := NewRootCoordMock()

8 changes: 6 additions & 2 deletions internal/proxy/task_upsert.go
@@ -317,8 +317,12 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
         // insert to _default partition
         partitionTag := it.req.GetPartitionName()
         if len(partitionTag) <= 0 {
-            partitionTag = Params.CommonCfg.DefaultPartitionName.GetValue()
-            it.req.PartitionName = partitionTag
+            pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.req.GetDbName(), collectionName, "")
+            if err != nil {
+                log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
+                return err
+            }
+            it.req.PartitionName = pinfo.name
         }
     }

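
Insert and upsert now resolve an empty partition name through the meta cache instead of reading Params.CommonCfg.DefaultPartitionName directly, so both tasks agree with GetPartitionInfo and the rate limiter, and a failed lookup aborts the task rather than proceeding with a guessed name. A rough sketch of the changed fallback, with lookupDefault standing in for globalMetaCache.GetPartitionInfo(ctx, db, collection, "") and the map argument standing in for the cached partition list:

package main

import (
    "errors"
    "fmt"
)

// lookupDefault stands in for the meta cache: it scans the cached partitions
// for the one flagged isDefault, as GetPartitionInfo now does.
func lookupDefault(infos map[string]bool) (string, error) {
    for name, isDefault := range infos {
        if isDefault {
            return name, nil
        }
    }
    return "", errors.New("default partition not found")
}

// resolvePartitionTag mirrors the patched PreExecute fallback.
func resolvePartitionTag(partitionTag string, infos map[string]bool) (string, error) {
    if len(partitionTag) <= 0 {
        name, err := lookupDefault(infos)
        if err != nil {
            return "", err // the task now fails fast on a lookup error
        }
        partitionTag = name
    }
    return partitionTag, nil
}

func main() {
    infos := map[string]bool{"_default": true, "p1": false}
    tag, err := resolvePartitionTag("", infos)
    fmt.Println(tag, err) // _default <nil>
}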
