Skip to content

Commit

Permalink
chore: last cleanup move back to non-interceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
yquansah committed Aug 7, 2023
1 parent 811faa6 commit e2f936a
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 245 deletions.
1 change: 0 additions & 1 deletion internal/cmd/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ func NewGRPCServer(
append(authInterceptors,
middlewaregrpc.ErrorUnaryInterceptor,
middlewaregrpc.ValidationUnaryInterceptor,
middlewaregrpc.SegmentKeysUnaryInterceptor,
middlewaregrpc.EvaluationUnaryInterceptor,
)...,
)
Expand Down
7 changes: 4 additions & 3 deletions internal/ext/importer.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ func (i *Importer) Import(ctx context.Context, r io.Reader) (err error) {
Rank: rank,
NamespaceKey: namespace,
SegmentOperator: flipt.SegmentOperator(flipt.SegmentOperator_value[r.SegmentOperator]),
SegmentKey: r.SegmentKey,
}

if len(r.SegmentKeys) > 0 && r.SegmentKey != "" {
Expand All @@ -264,8 +263,10 @@ func (i *Importer) Import(ctx context.Context, r io.Reader) (err error) {
)
}

// support explicitly setting only "segments" on rules from 1.2
if len(r.SegmentKeys) > 0 {
if r.SegmentKey != "" {
fcr.SegmentKey = r.SegmentKey
} else if len(r.SegmentKeys) > 0 {
// support explicitly setting only "segments" on rules from 1.2
if err := ensureFieldSupported("flag.rules[*].segments", semver.Version{
Major: 1,
Minor: 2,
Expand Down
16 changes: 0 additions & 16 deletions internal/server/middleware/grpc/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,6 @@ func EvaluationUnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.Un
return handler(ctx, req)
}

type SegmentKeysSanitizer interface {
// SanitizeSegmentKeys deduplicates the requests' `SegmentKey`, and `SegmentKeys`
// field, and combine it into one slice.
SanitizeSegmentKeys()
}

// SegmentKeysUnaryInterceptor sets the `SegmentKeys` field to the correct value for the request.
// TODO(yquansah): remove this once `SegmentKey` is no longer needed.
func SegmentKeysUnaryInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
if r, ok := req.(SegmentKeysSanitizer); ok {
r.SanitizeSegmentKeys()
}

return handler(ctx, req)
}

// CacheUnaryInterceptor caches the response of a request if the request is cacheable.
// TODO: we could clean this up by using generics in 1.18+ to avoid the type switch/duplicate code.
func CacheUnaryInterceptor(cache cache.Cacher, logger *zap.Logger) grpc.UnaryServerInterceptor {
Expand Down
106 changes: 2 additions & 104 deletions internal/server/middleware/grpc/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,111 +358,9 @@ func TestEvaluationUnaryInterceptor_BatchEvaluation(t *testing.T) {
// check that the requestID was propagated
assert.NotEmpty(t, resp.RequestId)
assert.Equal(t, "bar", resp.RequestId)
assert.NotZero(t, resp.RequestDurationMillis)
}

func TestSegmentKeysUnaryInterceptor(t *testing.T) {
type segmentKeysGetter interface {
GetSegmentKeys() []string
}

for _, test := range []struct {
name string
req interface{}
length int
}{
{
name: "flipt.CreateRuleRequest with just SegmentKey",
req: &flipt.CreateRuleRequest{
FlagKey: "foo",
SegmentKey: "segment_foo",
},
length: 1,
},
{
name: "flipt.UpdateRuleRequest with SegmentKeys and deduplicating",
req: &flipt.UpdateRuleRequest{
FlagKey: "foo",
SegmentKeys: []string{"segment_foo", "segment_foo", "segment_bar"},
},
length: 2,
},
} {
var (
handler = func(ctx context.Context, r interface{}) (interface{}, error) {
// ensure that the request can be asserted to `segmentKeysGetter`.
req, ok := r.(segmentKeysGetter)
require.True(t, ok)

assert.Len(t, req.GetSegmentKeys(), test.length)

return nil, nil
}

info = &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}
)

_, err := SegmentKeysUnaryInterceptor(context.Background(), test.req, info, handler)
require.NoError(t, err)
}

type segmentGetter interface {
GetSegment() *flipt.RolloutSegment
}

for _, test := range []struct {
name string
req interface{}
length int
}{
{
name: "flipt.CreateRolloutRequest with just SegmentKeys",
req: &flipt.CreateRolloutRequest{
FlagKey: "foo",
Rank: 1,
Rule: &flipt.CreateRolloutRequest_Segment{
Segment: &flipt.RolloutSegment{
SegmentKeys: []string{"segment_foo", "segment_bar"},
Value: true,
},
},
},
length: 2,
},
{
name: "flipt.UpdateRolloutRequest with SegmentKey",
req: &flipt.UpdateRolloutRequest{
FlagKey: "foo",
Rule: &flipt.UpdateRolloutRequest_Segment{
Segment: &flipt.RolloutSegment{
SegmentKey: "segment_foo",
},
},
},
length: 1,
},
} {
var (
handler = func(ctx context.Context, r interface{}) (interface{}, error) {
// ensure that the request can be asserted to `segmentGetter`.
req, ok := r.(segmentGetter)
require.True(t, ok)

assert.Len(t, req.GetSegment().GetSegmentKeys(), test.length)

return nil, nil
}

info = &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}
)

_, err := SegmentKeysUnaryInterceptor(context.Background(), test.req, info, handler)
require.NoError(t, err)
}
// TODO(yquansah): flakey assertion
// assert.NotZero(t, resp.RequestDurationMillis)
}

func TestCacheUnaryInterceptor_GetFlag(t *testing.T) {
Expand Down
14 changes: 9 additions & 5 deletions internal/storage/sql/common/rollout.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,8 @@ func (s *Store) CreateRollout(ctx context.Context, r *flipt.CreateRolloutRequest

var segmentRule = r.GetSegment()

segmentKeys := sanitizeSegmentKeys(segmentRule.GetSegmentKey(), segmentRule.GetSegmentKeys())

if _, err := s.builder.Insert(tableRolloutSegments).
RunWith(tx).
Columns("id", "rollout_id", "\"value\"", "segment_operator").
Expand All @@ -475,7 +477,7 @@ func (s *Store) CreateRollout(ctx context.Context, r *flipt.CreateRolloutRequest
return nil, err
}

for _, segmentKey := range segmentRule.SegmentKeys {
for _, segmentKey := range segmentKeys {
if _, err := s.builder.Insert(tableRolloutSegmentReferences).
RunWith(tx).
Columns("rollout_segment_id", "namespace_key", "segment_key").
Expand All @@ -490,10 +492,10 @@ func (s *Store) CreateRollout(ctx context.Context, r *flipt.CreateRolloutRequest
SegmentOperator: segmentRule.SegmentOperator,
}

if len(segmentRule.SegmentKeys) == 1 {
innerSegment.SegmentKey = segmentRule.SegmentKeys[0]
if len(segmentKeys) == 1 {
innerSegment.SegmentKey = segmentKeys[0]
} else {
innerSegment.SegmentKeys = segmentRule.SegmentKeys
innerSegment.SegmentKeys = segmentKeys
}

rollout.Rule = &flipt.Rollout_Segment{
Expand Down Expand Up @@ -579,6 +581,8 @@ func (s *Store) UpdateRollout(ctx context.Context, r *flipt.UpdateRolloutRequest

var segmentRule = r.GetSegment()

segmentKeys := sanitizeSegmentKeys(segmentRule.GetSegmentKey(), segmentRule.GetSegmentKeys())

if _, err := s.builder.Update(tableRolloutSegments).
RunWith(tx).
Set("segment_operator", segmentRule.SegmentOperator).
Expand Down Expand Up @@ -608,7 +612,7 @@ func (s *Store) UpdateRollout(ctx context.Context, r *flipt.UpdateRolloutRequest
return nil, err
}

for _, segmentKey := range segmentRule.SegmentKeys {
for _, segmentKey := range segmentKeys {
if _, err := s.builder.
Insert(tableRolloutSegmentReferences).
RunWith(tx).
Expand Down
14 changes: 9 additions & 5 deletions internal/storage/sql/common/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ func (s *Store) CountRules(ctx context.Context, namespaceKey, flagKey string) (u

// CreateRule creates a rule
func (s *Store) CreateRule(ctx context.Context, r *flipt.CreateRuleRequest) (_ *flipt.Rule, err error) {
segmentKeys := sanitizeSegmentKeys(r.GetSegmentKey(), r.GetSegmentKeys())

if r.NamespaceKey == "" {
r.NamespaceKey = storage.DefaultNamespace
}
Expand Down Expand Up @@ -410,7 +412,7 @@ func (s *Store) CreateRule(ctx context.Context, r *flipt.CreateRuleRequest) (_ *
return nil, err
}

for _, segmentKey := range r.SegmentKeys {
for _, segmentKey := range segmentKeys {
if _, err := s.builder.
Insert("rule_segments").
RunWith(tx).
Expand All @@ -425,17 +427,19 @@ func (s *Store) CreateRule(ctx context.Context, r *flipt.CreateRuleRequest) (_ *
}
}

if len(r.SegmentKeys) == 1 {
rule.SegmentKey = r.SegmentKeys[0]
if len(segmentKeys) == 1 {
rule.SegmentKey = segmentKeys[0]
} else {
rule.SegmentKeys = r.SegmentKeys
rule.SegmentKeys = segmentKeys
}

return rule, tx.Commit()
}

// UpdateRule updates an existing rule
func (s *Store) UpdateRule(ctx context.Context, r *flipt.UpdateRuleRequest) (_ *flipt.Rule, err error) {
segmentKeys := sanitizeSegmentKeys(r.GetSegmentKey(), r.GetSegmentKeys())

if r.NamespaceKey == "" {
r.NamespaceKey = storage.DefaultNamespace
}
Expand Down Expand Up @@ -470,7 +474,7 @@ func (s *Store) UpdateRule(ctx context.Context, r *flipt.UpdateRuleRequest) (_ *
return nil, err
}

for _, segmentKey := range r.SegmentKeys {
for _, segmentKey := range segmentKeys {
if _, err := s.builder.
Insert("rule_segments").
RunWith(tx).
Expand Down
29 changes: 29 additions & 0 deletions internal/storage/sql/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,32 @@ func decodePageToken(logger *zap.Logger, pageToken string) (PageToken, error) {

return token, nil
}

// removeDuplicates is an inner utility function that will deduplicate a slice of strings.
func removeDuplicates(src []string) []string {
allKeys := make(map[string]bool)

dest := []string{}

for _, item := range src {
if _, value := allKeys[item]; !value {
allKeys[item] = true
dest = append(dest, item)
}
}

return dest
}

// sanitizeSegmentKeys is a utility function that will transform segment keys into the right input.
func sanitizeSegmentKeys(segmentKey string, segmentKeys []string) []string {
result := make([]string, 0)

if len(segmentKeys) > 0 {
result = append(result, segmentKeys...)
} else if segmentKey != "" {
result = append(result, segmentKey)
}

return removeDuplicates(result)
}
2 changes: 1 addition & 1 deletion internal/storage/sql/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (s *Store) CreateRule(ctx context.Context, r *flipt.CreateRuleRequest) (*fl
var merr *mysql.MySQLError

if errors.As(err, &merr) && merr.Number == constraintForeignKeyErrCode {
return nil, errs.ErrNotFoundf(`flag "%s/%s" or segments "%s"`, r.NamespaceKey, r.FlagKey, r.NamespaceKey)
return nil, errs.ErrNotFoundf(`flag "%s/%s" or segment "%s/%s"`, r.NamespaceKey, r.FlagKey, r.NamespaceKey, r.SegmentKey)
}

return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/storage/sql/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (s *Store) CreateRule(ctx context.Context, r *flipt.CreateRuleRequest) (*fl
var perr *pq.Error

if errors.As(err, &perr) && perr.Code.Name() == constraintForeignKeyErr {
return nil, errs.ErrNotFoundf(`flag "%s/%s" or segments "%s"`, r.NamespaceKey, r.FlagKey, r.NamespaceKey)
return nil, errs.ErrNotFoundf(`flag "%s/%s" or segment "%s/%s"`, r.NamespaceKey, r.FlagKey, r.NamespaceKey, r.SegmentKey)
}

return nil, err
Expand Down
4 changes: 2 additions & 2 deletions internal/storage/sql/rollout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,8 @@ func (s *DBTestSuite) TestUpdateRollout() {
Rank: 1,
Rule: &flipt.CreateRolloutRequest_Segment{
Segment: &flipt.RolloutSegment{
Value: true,
SegmentKeys: []string{"segment_one"},
Value: true,
SegmentKey: "segment_one",
},
},
})
Expand Down
Loading

0 comments on commit e2f936a

Please sign in to comment.