Skip to content

Commit

Permalink
feat: improvements to config reloader, 100% coverage (#1933)
Browse files Browse the repository at this point in the history
Increased test coverage of reloader to 100%.

---------

Co-authored-by: Chris Stockton <[email protected]>
  • Loading branch information
cstockton and Chris Stockton authored Feb 4, 2025
1 parent fbbebcc commit 21c2256
Show file tree
Hide file tree
Showing 5 changed files with 456 additions and 45 deletions.
2 changes: 1 addition & 1 deletion hack/coverage.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FAIL=false

for PKG in "crypto"
for PKG in "crypto" "reloader"
do
UNCOVERED_FUNCS=$(go tool cover -func=coverage.out | grep "^github.com/supabase/auth/internal/$PKG/" | grep -v '100.0%$')
UNCOVERED_FUNCS_COUNT=$(echo "$UNCOVERED_FUNCS" | wc -l)
Expand Down
3 changes: 3 additions & 0 deletions internal/reloader/handler_race_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ func TestAtomicHandlerRaces(t *testing.T) {

hr.Store(hrFunc)

// Calling string should be safe
hr.String()

got := hr.load()
_, ok := hrFuncMap[got]
if !ok {
Expand Down
21 changes: 15 additions & 6 deletions internal/reloader/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package reloader

import (
"net/http"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -11,23 +12,31 @@ func TestAtomicHandler(t *testing.T) {
// for ptr identity
type testHandler struct{ http.Handler }

var calls atomic.Int64
hrFn := func() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls.Add(1)
})
}

hrFunc1 := &testHandler{hrFn()}
hrFunc2 := &testHandler{hrFn()}
assert.NotEqual(t, hrFunc1, hrFunc2)

// a new AtomicHandler should be non-nil
hr := NewAtomicHandler(nil)
hr := NewAtomicHandler(hrFunc1)
assert.NotNil(t, hr)
assert.Equal(t, "reloader.AtomicHandler", hr.String())

// should have no stored handler
// should implement http.Handler
{
hrCur := hr.load()
assert.Nil(t, hrCur)
assert.Equal(t, true, hrCur == nil)
v := (http.Handler)(hr)
before := calls.Load()
v.ServeHTTP(nil, nil)
after := calls.Load()
if exp, got := before+1, after; exp != got {
t.Fatalf("exp %v to be %v after handler was called", got, exp)
}
}

// should be non-nil after store
Expand Down
126 changes: 104 additions & 22 deletions internal/reloader/reloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package reloader

import (
"context"
"log"
"errors"
"strings"
"sync"
"time"

"github.com/fsnotify/fsnotify"
Expand All @@ -27,28 +28,24 @@ type Reloader struct {
watchDir string
reloadIval time.Duration
tickerIval time.Duration
watchFn func() (watcher, error)
reloadFn func(dir string) (*conf.GlobalConfiguration, error)
}

func NewReloader(watchDir string) *Reloader {
return &Reloader{
watchDir: watchDir,
reloadIval: reloadInterval,
tickerIval: tickerInterval,
watchFn: newFSWatcher,
reloadFn: defaultReloadFn,
}
}

// reload attempts to create a new *conf.GlobalConfiguration after loading the
// currently configured watchDir.
func (rl *Reloader) reload() (*conf.GlobalConfiguration, error) {
if err := conf.LoadDirectory(rl.watchDir); err != nil {
return nil, err
}

cfg, err := conf.LoadGlobalFromEnv()
if err != nil {
return nil, err
}
return cfg, nil
return rl.reloadFn(rl.watchDir)
}

// reloadCheckAt checks if reloadConfig should be called, returns true if config
Expand All @@ -66,9 +63,10 @@ func (rl *Reloader) reloadCheckAt(at, lastUpdate time.Time) bool {
}

func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
wr, err := fsnotify.NewWatcher()
wr, err := rl.watchFn()
if err != nil {
log.Fatal(err)
logrus.WithError(err).Error("reloader: error creating fsnotify Watcher")
return err
}
defer wr.Close()

Expand All @@ -77,7 +75,7 @@ func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {

// Ignore errors, if watch dir doesn't exist we can add it later.
if err := wr.Add(rl.watchDir); err != nil {
logrus.WithError(err).Error("watch dir failed")
logrus.WithError(err).Error("reloader: error watching config directory")
}

var lastUpdate time.Time
Expand All @@ -92,7 +90,7 @@ func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
// scenarios and wr.WatchList() does not grow which aligns with
// the documented behavior.
if err := wr.Add(rl.watchDir); err != nil {
logrus.WithError(err).Error("watch dir failed")
logrus.WithError(err).Error(err)
}

// Check to see if the config is ready to be relaoded.
Expand All @@ -105,17 +103,18 @@ func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {

cfg, err := rl.reload()
if err != nil {
logrus.WithError(err).Error("config reload failed")
logrus.WithError(err).Error("reloader: error loading config")
continue
}

// Call the callback function with the latest cfg.
fn(cfg)

case evt, ok := <-wr.Events:
case evt, ok := <-wr.Events():
if !ok {
logrus.WithError(err).Error("fsnotify has exited")
return nil
err := errors.New("reloader: fsnotify event channel was closed")
logrus.WithError(err).Error(err)
return err
}

// We only read files ending in .env
Expand All @@ -130,12 +129,95 @@ func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
evt.Op.Has(fsnotify.Write):
lastUpdate = time.Now()
}
case err, ok := <-wr.Errors:
case err, ok := <-wr.Errors():
if !ok {
logrus.Error("fsnotify has exited")
return nil
err := errors.New("reloader: fsnotify error channel was closed")
logrus.WithError(err).Error(err)
return err
}
logrus.WithError(err).Error("fsnotify has reported an error")
logrus.WithError(err).Error(
"reloader: fsnotify has reported an error")
}
}
}

func defaultReloadFn(dir string) (*conf.GlobalConfiguration, error) {
if err := conf.LoadDirectory(dir); err != nil {
return nil, err
}

cfg, err := conf.LoadGlobalFromEnv()
if err != nil {
return nil, err
}
return cfg, nil
}

type watcher interface {
Add(path string) error
Close() error
Events() chan fsnotify.Event
Errors() chan error
}

type fsNotifyWatcher struct {
wr *fsnotify.Watcher
}

func newFSWatcher() (watcher, error) {
wr, err := fsnotify.NewWatcher()
return &fsNotifyWatcher{wr}, err
}

func (o *fsNotifyWatcher) Add(path string) error { return o.wr.Add(path) }
func (o *fsNotifyWatcher) Close() error { return o.wr.Close() }
func (o *fsNotifyWatcher) Errors() chan error { return o.wr.Errors }
func (o *fsNotifyWatcher) Events() chan fsnotify.Event { return o.wr.Events }

type mockWatcher struct {
mu sync.Mutex
err error
eventCh chan fsnotify.Event
errorCh chan error
addCh chan string
}

func newMockWatcher(err error) *mockWatcher {
wr := &mockWatcher{
err: err,
eventCh: make(chan fsnotify.Event, 1024),
errorCh: make(chan error, 1024),
addCh: make(chan string, 1024),
}
return wr
}

func (o *mockWatcher) getErr() error {
o.mu.Lock()
defer o.mu.Unlock()
err := o.err
return err
}

func (o *mockWatcher) setErr(err error) {
o.mu.Lock()
defer o.mu.Unlock()
o.err = err
}

func (o *mockWatcher) Add(path string) error {
o.mu.Lock()
defer o.mu.Unlock()
if err := o.err; err != nil {
return err
}

select {
case o.addCh <- path:
default:
}
return nil
}
func (o *mockWatcher) Close() error { return o.getErr() }
func (o *mockWatcher) Events() chan fsnotify.Event { return o.eventCh }
func (o *mockWatcher) Errors() chan error { return o.errorCh }
Loading

0 comments on commit 21c2256

Please sign in to comment.