Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add collision detection, enabled by default #43

Merged
merged 4 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion xload/async.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,31 @@ import (
type loadAndSet func(context.Context, reflect.Value) error
type loadAndSetPointer func(context.Context, reflect.Value, reflect.Value, bool) error

func processConcurrently(ctx context.Context, v any, opts *options) error {
func processConcurrently(ctx context.Context, v any, o *options) error {
if !o.detectCollisions {
return doProcessConcurrently(ctx, v, o)
}

syncKeyUsage := &collisionSyncMap{}
ldr := o.loader
o.loader = LoaderFunc(func(ctx context.Context, key string) (string, error) {
v, err := ldr.Load(ctx, key)

if err == nil {
syncKeyUsage.add(key)
}

return v, err
})

if err := doProcessConcurrently(ctx, v, o); err != nil {
return err
}

return syncKeyUsage.err()
}

func doProcessConcurrently(ctx context.Context, v any, opts *options) error {
doneCh := make(chan struct{}, 1)
defer close(doneCh)

Expand Down
66 changes: 66 additions & 0 deletions xload/collision.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package xload

import (
"sort"
"sync"
)

type collisionSyncMap sync.Map

func (cm *collisionSyncMap) add(key string) {
m := (*sync.Map)(cm)
v, loaded := m.LoadOrStore(key, 1)

if loaded {
m.Store(key, v.(int)+1)
}
}

func (cm *collisionSyncMap) err() error {
var collidedKeys []string

m := (*sync.Map)(cm)
m.Range(func(key, v any) bool {
if key == "" {
return true
}

if count, _ := v.(int); count > 1 {
collidedKeys = append(collidedKeys, key.(string))
}

return true
})

return keysToErr(collidedKeys)
}

type collisionMap map[string]int

func (cm collisionMap) add(key string) { cm[key]++ }

func (cm collisionMap) err() error {
var collidedKeys []string

for key, count := range cm {
if key == "" {
continue
}

if count > 1 {
collidedKeys = append(collidedKeys, key)
}
}

return keysToErr(collidedKeys)
}

func keysToErr(collidedKeys []string) error {
if len(collidedKeys) == 0 {
return nil
}

sort.Strings(collidedKeys)

return &ErrCollision{keys: collidedKeys}
}
33 changes: 33 additions & 0 deletions xload/collision_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package xload

import (
"testing"

"github.com/stretchr/testify/assert"
)

func Test_collisionSyncMap_err(t *testing.T) {
tests := []struct {
name string
cm func() *collisionSyncMap
wantErr assert.ErrorAssertionFunc
}{
{
name: "empty keys",
cm: func() *collisionSyncMap {
m := &collisionSyncMap{}
m.add("")
m.add("")
m.add("")
return m
},
wantErr: assert.NoError,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.wantErr(t, tt.cm().err(), "err()")
})
}
}
8 changes: 8 additions & 0 deletions xload/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,11 @@ type ErrInvalidPrefixAndKey struct {
func (e ErrInvalidPrefixAndKey) Error() string {
return fmt.Sprintf("`%s` key=%s has both prefix and key", e.field, e.key)
}

// ErrCollision is returned when key collisions are detected.
// Collision can happen when two or more fields have the same full key.
type ErrCollision struct{ keys []string }
ajatprabha marked this conversation as resolved.
Show resolved Hide resolved

func (e *ErrCollision) Error() string {
return fmt.Sprintf("xload: key collisions detected for keys: %v", e.keys)
}
29 changes: 26 additions & 3 deletions xload/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,34 @@ func Load(ctx context.Context, v any, opts ...Option) error {
return processConcurrently(ctx, v, o)
}

return process(ctx, v, o.tagName, o.loader)
return process(ctx, v, o)
}

func process(ctx context.Context, v any, o *options) error {
if !o.detectCollisions {
return doProcess(ctx, v, o.tagName, o.loader)
}

keyUsage := make(collisionMap)
loaderWithKeyUsage := LoaderFunc(func(ctx context.Context, key string) (string, error) {
v, err := o.loader.Load(ctx, key)

if err == nil {
keyUsage.add(key)
}

return v, err
})

if err := doProcess(ctx, v, o.tagName, loaderWithKeyUsage); err != nil {
return err
}

return keyUsage.err()
}

//nolint:funlen,nestif
func process(ctx context.Context, obj any, tagKey string, loader Loader) error {
func doProcess(ctx context.Context, obj any, tagKey string, loader Loader) error {
v := reflect.ValueOf(obj)

if v.Kind() != reflect.Ptr {
Expand Down Expand Up @@ -141,7 +164,7 @@ func process(ctx context.Context, obj any, tagKey string, loader Loader) error {
pld = PrefixLoader(meta.prefix, loader)
}

err := process(ctx, fVal.Interface(), tagKey, pld)
err := doProcess(ctx, fVal.Interface(), tagKey, pld)
if err != nil {
return err
}
Expand Down
52 changes: 52 additions & 0 deletions xload/load_struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,58 @@ func TestLoad_Structs(t *testing.T) {
err: &ErrInvalidPrefixAndKey{field: "Address", key: "ADDRESS"},
loader: MapLoader{},
},

// key collision
{
name: "key collision",
input: &struct {
Address1 Address `env:",prefix=ADDRESS_"`
Address2 *Address `env:",prefix=ADDRESS_"`
}{},
err: &ErrCollision{keys: []string{
"ADDRESS_CITY",
"ADDRESS_LATITUDE",
"ADDRESS_LONGITUTE",
"ADDRESS_STREET",
}},
loader: MapLoader{
"ADDRESS_STREET": "street1",
"ADDRESS_CITY": "city1",
"ADDRESS_LONGITUTE": "1.1",
"ADDRESS_LATITUDE": "-2.2",
},
},
{
name: "key collision with detection disabled",
opts: []Option{SkipCollisionDetection},
input: &struct {
Address1 Address `env:",prefix=ADDRESS_"`
Address2 *Address `env:",prefix=ADDRESS_"`
}{},
want: &struct {
Address1 Address
Address2 *Address
}{
Address{
Street: "street1",
City: "city1",
Longitute: ptr.Float64(1.1),
Latitude: ptr.Float64(-2.2),
},
&Address{
Street: "street1",
City: "city1",
Longitute: ptr.Float64(1.1),
Latitude: ptr.Float64(-2.2),
},
},
loader: MapLoader{
"ADDRESS_STREET": "street1",
"ADDRESS_CITY": "city1",
"ADDRESS_LONGITUTE": "1.1",
"ADDRESS_LATITUDE": "-2.2",
},
},
}

runTestcases(t, testcases)
Expand Down
5 changes: 3 additions & 2 deletions xload/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type testcase struct {
input any
want any
loader Loader
opts []Option
err error
}

Expand Down Expand Up @@ -711,7 +712,7 @@ func runTestcases(t *testing.T, testcases []testcase) {
tc := tc

t.Run("Load_"+tc.name, func(t *testing.T) {
err := Load(context.Background(), tc.input, WithLoader(tc.loader))
err := Load(context.Background(), tc.input, append(tc.opts, WithLoader(tc.loader))...)
if tc.err != nil {
assert.Error(t, err)
assert.ErrorContains(t, err, tc.err.Error())
Expand All @@ -724,7 +725,7 @@ func runTestcases(t *testing.T, testcases []testcase) {
})

t.Run("LoadAsync_"+tc.name, func(t *testing.T) {
err := Load(context.Background(), tc.input, Concurrency(5), WithLoader(tc.loader))
err := Load(context.Background(), tc.input, append(tc.opts, Concurrency(5), WithLoader(tc.loader))...)
if tc.err != nil {
assert.Error(t, err)
assert.ErrorContains(t, err, tc.err.Error())
Expand Down
53 changes: 31 additions & 22 deletions xload/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,48 @@ const defaultKey = "env"
// Option configures the xload behaviour.
type Option interface{ apply(*options) }

// FieldTagName allows customising the struct tag name to use.
type FieldTagName string

func (k FieldTagName) apply(opts *options) { opts.tagName = string(k) }

// Concurrency allows customising the number of goroutines to use.
// Default is 1.
type Concurrency int

func (c Concurrency) apply(opts *options) { opts.concurrency = int(c) }

// WithLoader allows customising the loader to use.
func WithLoader(loader Loader) Option {
return optionFunc(func(opts *options) { opts.loader = loader })
}

// SkipCollisionDetection disables detecting any key collisions while trying to load full keys.
var SkipCollisionDetection = &applier{f: func(o *options) { o.detectCollisions = false }}

// optionFunc allows using a function as an Option.
type optionFunc func(*options)

func (f optionFunc) apply(opts *options) { f(opts) }

type applier struct{ f func(*options) }

func (a *applier) apply(opts *options) { a.f(opts) }

// options holds the configuration.
type options struct {
tagName string
loader Loader
concurrency int
tagName string
loader Loader
concurrency int
detectCollisions bool
}

func newOptions(opts ...Option) *options {
o := &options{
tagName: defaultKey,
loader: OSLoader(),
concurrency: 1,
tagName: defaultKey,
loader: OSLoader(),
concurrency: 1,
detectCollisions: true,
}

for _, opt := range opts {
Expand All @@ -30,19 +55,3 @@ func newOptions(opts ...Option) *options {

return o
}

// FieldTagName allows customising the struct tag name to use.
type FieldTagName string

func (k FieldTagName) apply(opts *options) { opts.tagName = string(k) }

// Concurrency allows customising the number of goroutines to use.
// Default is 1.
type Concurrency int

func (c Concurrency) apply(opts *options) { opts.concurrency = int(c) }

// WithLoader allows customising the loader to use.
func WithLoader(loader Loader) Option {
return optionFunc(func(opts *options) { opts.loader = loader })
}
Loading
Loading