Skip to content

Commit

Permalink
fix: introduced RWMutex to flag state to prevent concurrent r/w of map (
Browse files Browse the repository at this point in the history
#370)

## This PR
<!-- add the description of the PR here -->

- Introduces RWMutex on flag state to prevent concurrent read/write of
map.

### Related Issues
<!-- add here the GitHub issue that this PR resolves if applicable -->

Fixes #368 

### Notes
<!-- any additional notes for this PR -->

### Follow-up Tasks
<!-- anything that is related to this PR but not done here should be
noted under this section -->
<!-- if there is a need for a new issue, please link it here -->

### How to test
<!-- if applicable, add testing instructions under this section -->

---------

Signed-off-by: Skye Gill <[email protected]>
  • Loading branch information
skyerus authored Feb 3, 2023
1 parent 7cb20d9 commit 93e356b
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 26 deletions.
6 changes: 6 additions & 0 deletions pkg/eval/fractional_evaluation_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package eval

import (
"sync"
"testing"

"github.com/open-feature/flagd/pkg/logger"
Expand All @@ -10,6 +11,7 @@ import (

func TestFractionalEvaluation(t *testing.T) {
flags := Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down Expand Up @@ -113,6 +115,7 @@ func TestFractionalEvaluation(t *testing.T) {
},
"non even split": {
flags: Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down Expand Up @@ -164,6 +167,7 @@ func TestFractionalEvaluation(t *testing.T) {
},
"fallback to default variant if no email provided": {
flags: Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down Expand Up @@ -206,6 +210,7 @@ func TestFractionalEvaluation(t *testing.T) {
},
"fallback to default variant if invalid variant as result of fractional evaluation": {
flags: Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down Expand Up @@ -240,6 +245,7 @@ func TestFractionalEvaluation(t *testing.T) {
},
"fallback to default variant if percentages don't sum to 100": {
flags: Flags{
mx: &sync.RWMutex{},
Flags: map[string]Flag{
"headerColor": {
State: "ENABLED",
Expand Down
18 changes: 16 additions & 2 deletions pkg/eval/json_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"regexp"
"strconv"
"strings"
mxSync "sync"

"github.com/open-feature/flagd/pkg/sync"

Expand Down Expand Up @@ -47,6 +48,7 @@ func NewJSONEvaluator(logger *logger.Logger) *JSONEvaluator {
),
state: Flags{
Flags: map[string]Flag{},
mx: &mxSync.RWMutex{},
},
}
jsonlogic.AddOperator("fractionalEvaluation", ev.fractionalEvaluation)
Expand Down Expand Up @@ -110,6 +112,8 @@ func (je *JSONEvaluator) ResolveAllValues(reqID string, context *structpb.Struct
var variant string
var reason string
var err error
je.state.mx.RLock()
defer je.state.mx.RUnlock()
for flagKey, flag := range je.state.Flags {
defaultValue := flag.Variants[flag.DefaultVariant]
switch defaultValue.(type) {
Expand Down Expand Up @@ -161,6 +165,8 @@ func (je *JSONEvaluator) ResolveBooleanValue(reqID string, flagKey string, conte
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating boolean flag: %s", flagKey))
return resolve[bool](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants)
}
Expand All @@ -171,6 +177,8 @@ func (je *JSONEvaluator) ResolveStringValue(reqID string, flagKey string, contex
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating string flag: %s", flagKey))
return resolve[string](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants)
}
Expand All @@ -181,6 +189,8 @@ func (je *JSONEvaluator) ResolveFloatValue(reqID string, flagKey string, context
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating float flag: %s", flagKey))
value, variant, reason, err = resolve[float64](
reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants)
Expand All @@ -193,6 +203,8 @@ func (je *JSONEvaluator) ResolveIntValue(reqID string, flagKey string, context *
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating int flag: %s", flagKey))
var val float64
val, variant, reason, err = resolve[float64](
Expand All @@ -207,6 +219,8 @@ func (je *JSONEvaluator) ResolveObjectValue(reqID string, flagKey string, contex
reason string,
err error,
) {
je.state.mx.RLock()
defer je.state.mx.RUnlock()
je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating object flag: %s", flagKey))
return resolve[map[string]any](reqID, flagKey, context, je.evaluateVariant, je.state.Flags[flagKey].Variants)
}
Expand Down Expand Up @@ -256,7 +270,7 @@ func (je *JSONEvaluator) evaluateVariant(
variant = strings.ReplaceAll(strings.TrimSpace(result.String()), "\"", "")

// if this is a valid variant, return it
if _, ok := je.state.Flags[flagKey].Variants[variant]; ok {
if _, ok := flag.Variants[variant]; ok {
return variant, model.TargetingMatchReason, nil
}

Expand All @@ -266,7 +280,7 @@ func (je *JSONEvaluator) evaluateVariant(
reason = model.StaticReason
}

return je.state.Flags[flagKey].DefaultVariant, reason, nil
return flag.DefaultVariant, reason, nil
}

// configToFlags convert string configurations to flags and store them to pointer newFlags
Expand Down
34 changes: 29 additions & 5 deletions pkg/eval/json_evaluator_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"reflect"
"sync"

"github.com/open-feature/flagd/pkg/logger"
)
Expand All @@ -21,6 +22,7 @@ type Evaluators struct {
}

type Flags struct {
mx *sync.RWMutex
Flags map[string]Flag `json:"flags"`
}

Expand All @@ -29,7 +31,10 @@ func (f Flags) Add(logger *logger.Logger, source string, ff Flags) map[string]in
notifications := map[string]interface{}{}

for k, newFlag := range ff.Flags {
if storedFlag, ok := f.Flags[k]; ok && storedFlag.Source != source {
f.mx.RLock()
storedFlag, ok := f.Flags[k]
f.mx.RUnlock()
if ok && storedFlag.Source != source {
logger.Warn(fmt.Sprintf(
"flag with key %s from source %s already exist, overriding this with flag from source %s",
k,
Expand All @@ -45,7 +50,9 @@ func (f Flags) Add(logger *logger.Logger, source string, ff Flags) map[string]in

// Store the new version of the flag
newFlag.Source = source
f.mx.Lock()
f.Flags[k] = newFlag
f.mx.Unlock()
}

return notifications
Expand All @@ -56,14 +63,18 @@ func (f Flags) Update(logger *logger.Logger, source string, ff Flags) map[string
notifications := map[string]interface{}{}

for k, flag := range ff.Flags {
if storedFlag, ok := f.Flags[k]; !ok {
f.mx.RLock()
storedFlag, ok := f.Flags[k]
f.mx.RUnlock()
if !ok {
logger.Warn(
fmt.Sprintf("failed to update the flag, flag with key %s from source %s does not exisit.",
fmt.Sprintf("failed to update the flag, flag with key %s from source %s does not exist.",
k,
source))

continue
} else if storedFlag.Source != source {
}
if storedFlag.Source != source {
logger.Warn(fmt.Sprintf(
"flag with key %s from source %s already exist, overriding this with flag from source %s",
k,
Expand All @@ -78,7 +89,9 @@ func (f Flags) Update(logger *logger.Logger, source string, ff Flags) map[string
}

flag.Source = source
f.mx.Lock()
f.Flags[k] = flag
f.mx.Unlock()
}

return notifications
Expand All @@ -89,13 +102,18 @@ func (f Flags) Delete(logger *logger.Logger, source string, ff Flags) map[string
notifications := map[string]interface{}{}

for k := range ff.Flags {
if _, ok := f.Flags[k]; ok {
f.mx.RLock()
_, ok := f.Flags[k]
f.mx.RUnlock()
if ok {
notifications[k] = map[string]interface{}{
"type": string(NotificationDelete),
"source": source,
}

f.mx.Lock()
delete(f.Flags, k)
f.mx.Unlock()
} else {
logger.Warn(
fmt.Sprintf("failed to remove flag, flag with key %s from source %s does not exisit.",
Expand All @@ -111,6 +129,7 @@ func (f Flags) Delete(logger *logger.Logger, source string, ff Flags) map[string
func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string]interface{} {
notifications := map[string]interface{}{}

f.mx.Lock()
for k, v := range f.Flags {
if v.Source == source {
if _, ok := ff.Flags[k]; !ok {
Expand All @@ -124,11 +143,14 @@ func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string]
}
}
}
f.mx.Unlock()

for k, newFlag := range ff.Flags {
newFlag.Source = source

f.mx.RLock()
storedFlag, ok := f.Flags[k]
f.mx.RUnlock()
if !ok {
notifications[k] = map[string]interface{}{
"type": string(NotificationCreate),
Expand All @@ -151,8 +173,10 @@ func (f Flags) Merge(logger *logger.Logger, source string, ff Flags) map[string]
}
}

f.mx.Lock()
// Store the new version of the flag
f.Flags[k] = newFlag
f.mx.Unlock()
}

return notifications
Expand Down
Loading

0 comments on commit 93e356b

Please sign in to comment.