Skip to content

Commit

Permalink
enhance: [GoSDK] Add load option for field partial load (#35920)
Browse files Browse the repository at this point in the history
Related to #35415

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Sep 3, 2024
1 parent 9da8652 commit 69b1eea
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 14 deletions.
57 changes: 45 additions & 12 deletions client/maintenance_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,19 @@ type LoadCollectionOption interface {
}

type loadCollectionOption struct {
collectionName string
interval time.Duration
replicaNum int
collectionName string
interval time.Duration
replicaNum int
loadFields []string
skipLoadDynamicField bool
}

func (opt *loadCollectionOption) Request() *milvuspb.LoadCollectionRequest {
return &milvuspb.LoadCollectionRequest{
CollectionName: opt.collectionName,
ReplicaNumber: int32(opt.replicaNum),
CollectionName: opt.collectionName,
ReplicaNumber: int32(opt.replicaNum),
LoadFields: opt.loadFields,
SkipLoadDynamicField: opt.skipLoadDynamicField,
}
}

Expand All @@ -49,6 +53,16 @@ func (opt *loadCollectionOption) WithReplica(num int) *loadCollectionOption {
return opt
}

func (opt *loadCollectionOption) WithLoadFields(loadFields ...string) *loadCollectionOption {
opt.loadFields = loadFields
return opt
}

func (opt *loadCollectionOption) WithSkipLoadDynamicField(skipFlag bool) *loadCollectionOption {
opt.skipLoadDynamicField = skipFlag
return opt
}

func NewLoadCollectionOption(collectionName string) *loadCollectionOption {
return &loadCollectionOption{
collectionName: collectionName,
Expand All @@ -65,24 +79,43 @@ type LoadPartitionsOption interface {
var _ LoadPartitionsOption = (*loadPartitionsOption)(nil)

type loadPartitionsOption struct {
collectionName string
partitionNames []string
interval time.Duration
replicaNum int
collectionName string
partitionNames []string
interval time.Duration
replicaNum int
loadFields []string
skipLoadDynamicField bool
}

func (opt *loadPartitionsOption) Request() *milvuspb.LoadPartitionsRequest {
return &milvuspb.LoadPartitionsRequest{
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
ReplicaNumber: int32(opt.replicaNum),
CollectionName: opt.collectionName,
PartitionNames: opt.partitionNames,
ReplicaNumber: int32(opt.replicaNum),
LoadFields: opt.loadFields,
SkipLoadDynamicField: opt.skipLoadDynamicField,
}
}

func (opt *loadPartitionsOption) CheckInterval() time.Duration {
return opt.interval
}

func (opt *loadPartitionsOption) WithReplica(num int) *loadPartitionsOption {
opt.replicaNum = num
return opt
}

func (opt *loadPartitionsOption) WithLoadFields(loadFields ...string) *loadPartitionsOption {
opt.loadFields = loadFields
return opt
}

func (opt *loadPartitionsOption) WithSkipLoadDynamicField(skipFlag bool) *loadPartitionsOption {
opt.skipLoadDynamicField = skipFlag
return opt
}

func NewLoadPartitionsOption(collectionName string, partitionsNames []string) *loadPartitionsOption {
return &loadPartitionsOption{
collectionName: collectionName,
Expand Down
21 changes: 19 additions & 2 deletions client/maintenance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package client
import (
"context"
"fmt"
"math/rand"
"testing"
"time"

Expand All @@ -41,10 +42,15 @@ func (s *MaintenanceSuite) TestLoadCollection() {
defer cancel()
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
fieldNames := []string{"id", "part", "vector"}
replicaNum := rand.Intn(3) + 1

done := atomic.NewBool(false)
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lcr *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) {
s.Equal(collectionName, lcr.GetCollectionName())
s.ElementsMatch(fieldNames, lcr.GetLoadFields())
s.True(lcr.SkipLoadDynamicField)
s.EqualValues(replicaNum, lcr.GetReplicaNumber())
return merr.Success(), nil
}).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
Expand All @@ -62,7 +68,10 @@ func (s *MaintenanceSuite) TestLoadCollection() {
})
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()

task, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName))
task, err := s.client.LoadCollection(ctx, NewLoadCollectionOption(collectionName).
WithReplica(replicaNum).
WithLoadFields(fieldNames...).
WithSkipLoadDynamicField(true))
s.NoError(err)

ch := make(chan struct{})
Expand Down Expand Up @@ -103,11 +112,16 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
s.Run("success", func() {
collectionName := fmt.Sprintf("coll_%s", s.randString(6))
partitionName := fmt.Sprintf("part_%s", s.randString(6))
fieldNames := []string{"id", "part", "vector"}
replicaNum := rand.Intn(3) + 1

done := atomic.NewBool(false)
s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, lpr *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) {
s.Equal(collectionName, lpr.GetCollectionName())
s.ElementsMatch([]string{partitionName}, lpr.GetPartitionNames())
s.ElementsMatch(fieldNames, lpr.GetLoadFields())
s.True(lpr.SkipLoadDynamicField)
s.EqualValues(replicaNum, lpr.GetReplicaNumber())
return merr.Success(), nil
}).Once()
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, glpr *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
Expand All @@ -126,7 +140,10 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
})
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()

task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, []string{partitionName}))
task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, []string{partitionName}).
WithReplica(replicaNum).
WithLoadFields(fieldNames...).
WithSkipLoadDynamicField(true))
s.NoError(err)

ch := make(chan struct{})
Expand Down

0 comments on commit 69b1eea

Please sign in to comment.