diff --git a/cmd/flipt/main.go b/cmd/flipt/main.go index c59786ff5f..b9cd2e5509 100644 --- a/cmd/flipt/main.go +++ b/cmd/flipt/main.go @@ -29,6 +29,7 @@ import ( "github.com/phyber/negroni-gzip/gzip" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/cobra" + "go.flipt.io/flipt/internal/cleanup" "go.flipt.io/flipt/internal/config" "go.flipt.io/flipt/internal/info" "go.flipt.io/flipt/internal/server" @@ -41,6 +42,7 @@ import ( "go.flipt.io/flipt/internal/storage" authstorage "go.flipt.io/flipt/internal/storage/auth" authsql "go.flipt.io/flipt/internal/storage/auth/sql" + oplocksql "go.flipt.io/flipt/internal/storage/oplock/sql" "go.flipt.io/flipt/internal/storage/sql" "go.flipt.io/flipt/internal/storage/sql/mysql" "go.flipt.io/flipt/internal/storage/sql/postgres" @@ -455,7 +457,21 @@ func run(ctx context.Context, logger *zap.Logger) error { otelgrpc.UnaryServerInterceptor(), } - authenticationStore := authsql.NewStore(driver, sql.BuilderFor(db, driver), logger) + var ( + sqlBuilder = sql.BuilderFor(db, driver) + authenticationStore = authsql.NewStore(driver, sqlBuilder, logger) + operationLockService = oplocksql.New(logger, driver, sqlBuilder) + ) + + if cfg.Authentication.ShouldRunCleanup() { + cleanupAuthService := cleanup.NewAuthenticationService(logger, operationLockService, authenticationStore, cfg.Authentication) + cleanupAuthService.Run(ctx) + + shutdownFuncs = append(shutdownFuncs, func(context.Context) { + _ = cleanupAuthService.Stop() + logger.Info("cleanup service has been shutdown") + }) + } // only enable enforcement middleware if authentication required if cfg.Authentication.Required { diff --git a/config/migrations/cockroachdb/2_create_table_operation_lock.down.sql b/config/migrations/cockroachdb/2_create_table_operation_lock.down.sql new file mode 100644 index 0000000000..54feca6d3d --- /dev/null +++ b/config/migrations/cockroachdb/2_create_table_operation_lock.down.sql @@ -0,0 +1 @@ +DROP TABLE IF 
EXISTS operation_lock; diff --git a/config/migrations/cockroachdb/2_create_table_operation_lock.up.sql b/config/migrations/cockroachdb/2_create_table_operation_lock.up.sql new file mode 100644 index 0000000000..ca9392a10a --- /dev/null +++ b/config/migrations/cockroachdb/2_create_table_operation_lock.up.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS operation_lock ( + operation VARCHAR(255) PRIMARY KEY UNIQUE NOT NULL, + version INTEGER DEFAULT 0 NOT NULL, + last_acquired_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + acquired_until TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); diff --git a/config/migrations/mysql/3_create_table_operation_lock.down.sql b/config/migrations/mysql/3_create_table_operation_lock.down.sql new file mode 100644 index 0000000000..54feca6d3d --- /dev/null +++ b/config/migrations/mysql/3_create_table_operation_lock.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS operation_lock; diff --git a/config/migrations/mysql/3_create_table_operation_lock.up.sql b/config/migrations/mysql/3_create_table_operation_lock.up.sql new file mode 100644 index 0000000000..372fb6116f --- /dev/null +++ b/config/migrations/mysql/3_create_table_operation_lock.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS operation_lock ( + operation VARCHAR(255) UNIQUE NOT NULL, + version INTEGER DEFAULT 0 NOT NULL, + last_acquired_at TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP(6), + acquired_until TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP(6), + PRIMARY KEY (`operation`) +); diff --git a/config/migrations/postgres/5_create_table_operation_lock.down.sql b/config/migrations/postgres/5_create_table_operation_lock.down.sql new file mode 100644 index 0000000000..54feca6d3d --- /dev/null +++ b/config/migrations/postgres/5_create_table_operation_lock.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS operation_lock; diff --git a/config/migrations/postgres/5_create_table_operation_lock.up.sql b/config/migrations/postgres/5_create_table_operation_lock.up.sql new file mode 100644 index 0000000000..ca9392a10a --- 
/dev/null +++ b/config/migrations/postgres/5_create_table_operation_lock.up.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS operation_lock ( + operation VARCHAR(255) PRIMARY KEY UNIQUE NOT NULL, + version INTEGER DEFAULT 0 NOT NULL, + last_acquired_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + acquired_until TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); diff --git a/config/migrations/sqlite3/5_create_table_operation_lock.down.sql b/config/migrations/sqlite3/5_create_table_operation_lock.down.sql new file mode 100644 index 0000000000..54feca6d3d --- /dev/null +++ b/config/migrations/sqlite3/5_create_table_operation_lock.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS operation_lock; diff --git a/config/migrations/sqlite3/5_create_table_operation_lock.up.sql b/config/migrations/sqlite3/5_create_table_operation_lock.up.sql new file mode 100644 index 0000000000..ca9392a10a --- /dev/null +++ b/config/migrations/sqlite3/5_create_table_operation_lock.up.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS operation_lock ( + operation VARCHAR(255) PRIMARY KEY UNIQUE NOT NULL, + version INTEGER DEFAULT 0 NOT NULL, + last_acquired_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + acquired_until TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); diff --git a/errors/errors.go b/errors/errors.go index 7434e5533b..50d175418a 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -5,6 +5,12 @@ import ( "fmt" ) +// As is a utility for one-lining errors.As statements. +// e.g. cerr, match := errors.As[MyCustomError](err). 
+func As[E error](err error) (e E, _ bool) { + return e, errors.As(err, &e) +} + // New creates a new error with errors.New func New(s string) error { return errors.New(s) diff --git a/internal/cleanup/cleanup.go b/internal/cleanup/cleanup.go new file mode 100644 index 0000000000..16c559dcaf --- /dev/null +++ b/internal/cleanup/cleanup.go @@ -0,0 +1,114 @@ +package cleanup + +import ( + "context" + "fmt" + "time" + + "go.flipt.io/flipt/internal/config" + authstorage "go.flipt.io/flipt/internal/storage/auth" + "go.flipt.io/flipt/internal/storage/oplock" + "go.flipt.io/flipt/rpc/flipt/auth" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +const minCleanupInterval = 5 * time.Minute + +// AuthenticationService is configured to run background goroutines which +// will clear out expired authentication tokens. +type AuthenticationService struct { + logger *zap.Logger + lock oplock.Service + store authstorage.Store + config config.AuthenticationConfig + + errgroup errgroup.Group + cancel func() +} + +// NewAuthenticationService constructs and configures a new instance of authentication service. +func NewAuthenticationService(logger *zap.Logger, lock oplock.Service, store authstorage.Store, config config.AuthenticationConfig) *AuthenticationService { + return &AuthenticationService{ + logger: logger, + lock: lock, + store: store, + config: config, + cancel: func() {}, + } +} + +func (s *AuthenticationService) schedules() map[auth.Method]config.AuthenticationCleanupSchedule { + schedules := map[auth.Method]config.AuthenticationCleanupSchedule{} + if s.config.Methods.Token.Cleanup != nil { + schedules[auth.Method_METHOD_TOKEN] = *s.config.Methods.Token.Cleanup + } + + return schedules +} + +// Run starts up a background goroutine per configure authentication method schedule. 
+func (s *AuthenticationService) Run(ctx context.Context) { + ctx, s.cancel = context.WithCancel(ctx) + + for method, schedule := range s.schedules() { + var ( + method = method + schedule = schedule + operation = oplock.Operation(fmt.Sprintf("cleanup_auth_%s", method)) + ) + + s.errgroup.Go(func() error { + // on the first attempt to run the cleanup authentication service + // we attempt to obtain the lock immediately. If the lock is already + // held, TryAcquire returns false along with the current entry's + // acquired-until timestamp + acquiredUntil := time.Now().UTC() + for { + select { + case <-ctx.Done(): + return nil + case <-time.After(time.Until(acquiredUntil)): + } + + acquired, entry, err := s.lock.TryAcquire(ctx, operation, schedule.Interval) + if err != nil { + // ensure we don't go into a hot loop when the operation lock service + // enters an error state by ensuring we sleep for at-least the minimum + // interval. + now := time.Now().UTC() + if acquiredUntil.Before(now) { + acquiredUntil = now.Add(minCleanupInterval) + } + + s.logger.Warn("attempting to acquire lock", zap.Error(err)) + continue + } + + // update the next sleep target to the current entry's acquired-until + acquiredUntil = entry.AcquiredUntil + + if !acquired { + s.logger.Info("cleanup process not acquired", zap.Time("next_attempt", entry.AcquiredUntil)) + continue + } + + expiredBefore := time.Now().UTC().Add(-schedule.GracePeriod) + s.logger.Info("cleanup process deleting authentications", zap.Time("expired_before", expiredBefore)) + if err := s.store.DeleteAuthentications(ctx, authstorage.Delete( + authstorage.WithMethod(method), + authstorage.WithExpiredBefore(expiredBefore), + )); err != nil { + s.logger.Error("attempting to delete expired authentications", zap.Error(err)) + } + } + }) + } +} + +// Stop signals for the cleanup goroutines to cancel and waits for them to finish. 
+func (s *AuthenticationService) Stop() error { + s.cancel() + + return s.errgroup.Wait() +} diff --git a/internal/cleanup/cleanup_test.go b/internal/cleanup/cleanup_test.go new file mode 100644 index 0000000000..b805018521 --- /dev/null +++ b/internal/cleanup/cleanup_test.go @@ -0,0 +1,95 @@ +package cleanup + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.flipt.io/flipt/internal/config" + authstorage "go.flipt.io/flipt/internal/storage/auth" + inmemauth "go.flipt.io/flipt/internal/storage/auth/memory" + inmemoplock "go.flipt.io/flipt/internal/storage/oplock/memory" + "go.flipt.io/flipt/rpc/flipt/auth" + "go.uber.org/zap/zaptest" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestCleanup(t *testing.T) { + var ( + ctx = context.Background() + logger = zaptest.NewLogger(t) + authstore = inmemauth.NewStore() + lock = inmemoplock.New() + authConfig = config.AuthenticationConfig{ + Methods: config.AuthenticationMethods{ + Token: config.AuthenticationMethodTokenConfig{ + Enabled: true, + Cleanup: &config.AuthenticationCleanupSchedule{ + Interval: time.Second, + GracePeriod: 5 * time.Second, + }, + }, + }, + } + ) + + // create an initial non-expiring token + clientToken, storedAuth, err := authstore.CreateAuthentication( + ctx, + &authstorage.CreateAuthenticationRequest{Method: auth.Method_METHOD_TOKEN}, + ) + require.NoError(t, err) + + for i := 0; i < 5; i++ { + // run five instances of service + // it should be a safe operation given they share the same lock service + service := NewAuthenticationService(logger, lock, authstore, authConfig) + service.Run(ctx) + defer func() { + require.NoError(t, service.Stop()) + }() + } + + t.Run("ensure non-expiring token exists", func(t *testing.T) { + retrievedAuth, err := authstore.GetAuthenticationByClientToken(ctx, clientToken) + require.NoError(t, err) + assert.Equal(t, storedAuth, retrievedAuth) + }) + + 
t.Run("create an expiring token and ensure it exists", func(t *testing.T) { + clientToken, storedAuth, err = authstore.CreateAuthentication( + ctx, + &authstorage.CreateAuthenticationRequest{ + Method: auth.Method_METHOD_TOKEN, + ExpiresAt: timestamppb.New(time.Now().UTC().Add(5 * time.Second)), + }, + ) + require.NoError(t, err) + + retrievedAuth, err := authstore.GetAuthenticationByClientToken(ctx, clientToken) + require.NoError(t, err) + assert.Equal(t, storedAuth, retrievedAuth) + }) + + t.Run("ensure grace period protects token from being deleted", func(t *testing.T) { + // token should still exist as it wont be deleted until + // expiry + grace period (5s + 5s == 10s) + time.Sleep(5 * time.Second) + + retrievedAuth, err := authstore.GetAuthenticationByClientToken(ctx, clientToken) + require.NoError(t, err) + assert.Equal(t, storedAuth, retrievedAuth) + + // ensure authentication is expired but still persisted + assert.True(t, retrievedAuth.ExpiresAt.AsTime().Before(time.Now().UTC())) + }) + + t.Run("once expiry and grace period ellapses ensure token is deleted", func(t *testing.T) { + time.Sleep(10 * time.Second) + + _, err := authstore.GetAuthenticationByClientToken(ctx, clientToken) + require.Error(t, err, "resource not found") + }) +} diff --git a/internal/config/authentication.go b/internal/config/authentication.go index f03dfa987c..fb6753eb21 100644 --- a/internal/config/authentication.go +++ b/internal/config/authentication.go @@ -1,8 +1,28 @@ package config -import "github.com/spf13/viper" +import ( + "strings" + "time" -var _ defaulter = (*AuthenticationConfig)(nil) + "github.com/spf13/viper" + "go.flipt.io/flipt/rpc/flipt/auth" +) + +var ( + _ defaulter = (*AuthenticationConfig)(nil) + stringToAuthMethod = map[string]auth.Method{} +) + +func init() { + for method, v := range auth.Method_value { + if auth.Method(v) == auth.Method_METHOD_NONE { + continue + } + + name := strings.ToLower(strings.TrimPrefix(method, "METHOD_")) + 
stringToAuthMethod[name] = auth.Method(v) + } +} // AuthenticationConfig configures Flipts authentication mechanisms type AuthenticationConfig struct { @@ -11,24 +31,69 @@ type AuthenticationConfig struct { // Else, authentication is not required and Flipt's APIs are not secured. Required bool `json:"required,omitempty" mapstructure:"required"` - Methods struct { - Token AuthenticationMethodTokenConfig `json:"token,omitempty" mapstructure:"token"` - } `json:"methods,omitempty" mapstructure:"methods"` + Methods AuthenticationMethods `json:"methods,omitempty" mapstructure:"methods"` } -func (a *AuthenticationConfig) setDefaults(v *viper.Viper) []string { +// ShouldRunCleanup returns true if the cleanup background process should be started. +// It returns true given at-least 1 method is enabled and its associated schedule +// has been configured (non-nil). +func (c AuthenticationConfig) ShouldRunCleanup() bool { + return (c.Methods.Token.Enabled && c.Methods.Token.Cleanup != nil) +} + +func (c *AuthenticationConfig) setDefaults(v *viper.Viper) []string { + token := map[string]any{ + "enabled": false, + } + + if v.GetBool("authentication.methods.token.enabled") { + token["cleanup"] = map[string]any{ + "interval": time.Hour, + "grace_period": 30 * time.Minute, + } + } + v.SetDefault("authentication", map[string]any{ "required": false, "methods": map[string]any{ - "token": map[string]any{ - "enabled": false, - }, + "token": token, }, }) return nil } +func (c *AuthenticationConfig) validate() error { + for _, cleanup := range []struct { + name string + schedule *AuthenticationCleanupSchedule + }{ + // add additional schedules as token methods are created + {"token", c.Methods.Token.Cleanup}, + } { + if cleanup.schedule == nil { + continue + } + + field := "authentication.methods." + cleanup.name + if cleanup.schedule.Interval <= 0 { + return errFieldWrap(field+".cleanup.interval", errPositiveNonZeroDuration) + } + + if cleanup.schedule.GracePeriod <= 0 { + return 
errFieldWrap(field+".cleanup.grace_period", errPositiveNonZeroDuration) + } + } + + return nil +} + +// AuthenticationMethods is a set of configuration for each authentication +// method available for use within Flipt. +type AuthenticationMethods struct { + Token AuthenticationMethodTokenConfig `json:"token,omitempty" mapstructure:"token"` +} + // AuthenticationMethodTokenConfig contains fields used to configure the authentication // method "token". // This authentication method supports the ability to create static tokens via the @@ -36,5 +101,12 @@ func (a *AuthenticationConfig) setDefaults(v *viper.Viper) []string { type AuthenticationMethodTokenConfig struct { // Enabled designates whether or not static token authentication is enabled // and whether Flipt will mount the "token" method APIs. - Enabled bool `json:"enabled,omitempty" mapstructure:"enabled"` + Enabled bool `json:"enabled,omitempty" mapstructure:"enabled"` + Cleanup *AuthenticationCleanupSchedule `json:"cleanup,omitempty" mapstructure:"cleanup"` +} + +// AuthenticationCleanupSchedule is used to configure a cleanup goroutine. +type AuthenticationCleanupSchedule struct { + Interval time.Duration `json:"interval,omitempty" mapstructure:"interval"` + GracePeriod time.Duration `json:"gracePeriod,omitempty" mapstructure:"grace_period"` } diff --git a/internal/config/config.go b/internal/config/config.go index 6d00496dc7..77c39f7ba8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,7 @@ var decodeHooks = mapstructure.ComposeDecodeHookFunc( stringToEnumHookFunc(stringToCacheBackend), stringToEnumHookFunc(stringToScheme), stringToEnumHookFunc(stringToDatabaseProtocol), + stringToEnumHookFunc(stringToAuthMethod), ) // Config contains all of Flipts configuration needs. 
@@ -134,8 +135,8 @@ func bindEnvVars(v *viper.Viper, prefix string, field reflect.StructField) { // descend into struct fields if typ.Kind() == reflect.Struct { - for i := 0; i < field.Type.NumField(); i++ { - structField := field.Type.Field(i) + for i := 0; i < typ.NumField(); i++ { + structField := typ.Field(i) // key becomes prefix for sub-fields bindEnvVars(v, key+".", structField) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 1903cdfb32..6ed1c9478d 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -352,6 +352,16 @@ func TestLoad(t *testing.T) { path: "./testdata/database/missing_name.yml", wantErr: errValidationRequired, }, + { + name: "authentication - negative interval", + path: "./testdata/authentication/negative_interval.yml", + wantErr: errPositiveNonZeroDuration, + }, + { + name: "authentication - zero grace_period", + path: "./testdata/authentication/zero_grace_period.yml", + wantErr: errPositiveNonZeroDuration, + }, { name: "advanced", path: "./testdata/advanced.yml", @@ -402,6 +412,18 @@ func TestLoad(t *testing.T) { CheckForUpdates: false, TelemetryEnabled: false, } + cfg.Authentication = AuthenticationConfig{ + Required: true, + Methods: AuthenticationMethods{ + Token: AuthenticationMethodTokenConfig{ + Enabled: true, + Cleanup: &AuthenticationCleanupSchedule{ + Interval: 2 * time.Hour, + GracePeriod: 48 * time.Hour, + }, + }, + }, + } return cfg }, }, diff --git a/internal/config/errors.go b/internal/config/errors.go index f2af2e3cbd..8716850b1f 100644 --- a/internal/config/errors.go +++ b/internal/config/errors.go @@ -11,6 +11,8 @@ var ( // errValidationRequired is returned when a required value is // either not supplied or supplied with empty value. errValidationRequired = errors.New("non-empty value is required") + // errPositiveNonZeroDuration is returned when a negative or zero time.Duration is provided. 
+ errPositiveNonZeroDuration = errors.New("positive non-zero duration required") ) func errFieldWrap(field string, err error) error { diff --git a/internal/config/testdata/advanced.yml b/internal/config/testdata/advanced.yml index 4279e19a5c..24eae1f510 100644 --- a/internal/config/testdata/advanced.yml +++ b/internal/config/testdata/advanced.yml @@ -39,3 +39,12 @@ db: meta: check_for_updates: false telemetry_enabled: false + +authentication: + required: true + methods: + token: + enabled: true + cleanup: + interval: 2h + grace_period: 48h diff --git a/internal/config/testdata/authentication/negative_interval.yml b/internal/config/testdata/authentication/negative_interval.yml new file mode 100644 index 0000000000..e400dc8ec5 --- /dev/null +++ b/internal/config/testdata/authentication/negative_interval.yml @@ -0,0 +1,5 @@ +authentication: + methods: + token: + cleanup: + interval: -1m diff --git a/internal/config/testdata/authentication/zero_grace_period.yml b/internal/config/testdata/authentication/zero_grace_period.yml new file mode 100644 index 0000000000..57c7210dbe --- /dev/null +++ b/internal/config/testdata/authentication/zero_grace_period.yml @@ -0,0 +1,5 @@ +authentication: + methods: + token: + cleanup: + grace_period: 0 diff --git a/internal/storage/auth/memory/store.go b/internal/storage/auth/memory/store.go index cace50cb9d..9e58659914 100644 --- a/internal/storage/auth/memory/store.go +++ b/internal/storage/auth/memory/store.go @@ -200,7 +200,8 @@ func (s *Store) DeleteAuthentications(_ context.Context, req *auth.DeleteAuthent for hashedToken, a := range s.byToken { if (req.ID == nil || *req.ID == a.Id) && (req.Method == nil || *req.Method == a.Method) && - (req.ExpiredBefore == nil || a.ExpiresAt.AsTime().Before(req.ExpiredBefore.AsTime())) { + (req.ExpiredBefore == nil || + (a.ExpiresAt != nil && a.ExpiresAt.AsTime().Before(req.ExpiredBefore.AsTime()))) { delete(s.byID, a.Id) delete(s.byToken, hashedToken) } diff --git 
a/internal/storage/auth/testing/testing.go b/internal/storage/auth/testing/testing.go index eb303ca58a..872d41bf1b 100644 --- a/internal/storage/auth/testing/testing.go +++ b/internal/storage/auth/testing/testing.go @@ -42,10 +42,16 @@ func TestAuthenticationStoreHarness(t *testing.T, fn func(t *testing.T) storagea t.Run(fmt.Sprintf("Create %d authentications", len(created)), func(t *testing.T) { uniqueTokens := make(map[string]struct{}, len(created)) for i := 0; i < len(created); i++ { + // the first token will have a null expiration + var expires *timestamppb.Timestamp + if i > 0 { + expires = timestamppb.New(time.Unix(int64(i+1), 0)) + } + token, auth, err := store.CreateAuthentication(ctx, &storageauth.CreateAuthenticationRequest{ Method: auth.Method_METHOD_TOKEN, // from t1 to t100. - ExpiresAt: timestamppb.New(time.Unix(int64(i+1), 0)), + ExpiresAt: expires, Metadata: map[string]string{ "name": fmt.Sprintf("foo_%d", i+1), "description": "bar", @@ -116,11 +122,11 @@ func TestAuthenticationStoreHarness(t *testing.T, fn func(t *testing.T) storagea }) t.Run("Delete a single instance by ID", func(t *testing.T) { - req := storageauth.Delete(storageauth.WithID(created[0].Auth.Id)) + req := storageauth.Delete(storageauth.WithID(created[99].Auth.Id)) err := store.DeleteAuthentications(ctx, req) require.NoError(t, err) - auth, err := store.GetAuthenticationByClientToken(ctx, created[0].Token) + auth, err := store.GetAuthenticationByClientToken(ctx, created[99].Token) var expected errors.ErrNotFound if !assert.ErrorAs(t, err, &expected, "authentication still exists in the database") { t.Log("Auth still exists", auth) @@ -143,8 +149,9 @@ func TestAuthenticationStoreHarness(t *testing.T, fn func(t *testing.T) storagea all, err := storage.ListAll(ctx, store.ListAuthentications, storage.ListAllParams{}) require.NoError(t, err) - // ensure only the most recent 50 expires_at timestamped authentications remain - if !assert.Equal(t, allAuths(created[50:]), all) { + // ensure 
only the most recent 49 expires_at timestamped authentications remain + // along with the first authentication without an expiry + if !assert.Equal(t, allAuths(append(created[:1], created[50:99]...)), all) { fmt.Println("Found:", len(all)) } }) @@ -165,7 +172,28 @@ func TestAuthenticationStoreHarness(t *testing.T, fn func(t *testing.T) storagea require.NoError(t, err) // ensure only the most recent 25 expires_at timestamped authentications remain - if !assert.Equal(t, allAuths(created[75:]), all) { + if !assert.Equal(t, allAuths(append(created[:1], created[75:99]...)), all) { + fmt.Println("Found:", len(all)) + } + }) + + t.Run("Delete the rest of the tokens with an expiration", func(t *testing.T) { + // all tokens with expiry before t76 + req := storageauth.Delete( + storageauth.WithExpiredBefore(time.Unix(101, 0).UTC()), + ) + + err := store.DeleteAuthentications( + ctx, + req, + ) + require.NoError(t, err) + + all, err := storage.ListAll(ctx, store.ListAuthentications, storage.ListAllParams{}) + require.NoError(t, err) + + // ensure only the the first token with no expiry exists + if !assert.Equal(t, allAuths(created[:1]), all) { fmt.Println("Found:", len(all)) } }) diff --git a/internal/storage/oplock/memory/memory.go b/internal/storage/oplock/memory/memory.go new file mode 100644 index 0000000000..f0e0e3533f --- /dev/null +++ b/internal/storage/oplock/memory/memory.go @@ -0,0 +1,52 @@ +package memory + +import ( + "context" + "sync" + "time" + + "go.flipt.io/flipt/internal/storage/oplock" +) + +// Service is an in-memory implementation of the oplock.Service. +// It is only safe for single instance / in-process use. +type Service struct { + mu sync.Mutex + + ops map[oplock.Operation]oplock.LockEntry +} + +// New constructs and configures a new service instance. +func New() *Service { + return &Service{ops: map[oplock.Operation]oplock.LockEntry{}} +} + +// TryAcquire will attempt to obtain a lock for the supplied operation name for the specified duration. 
+// If it succeeds then the returned boolean (acquired) will be true, else false. +// The lock entry associated with the last successful acquisition is also returned. +// Given the lock was acquired successfully this will be the entry just created. +func (s *Service) TryAcquire(ctx context.Context, operation oplock.Operation, duration time.Duration) (acquired bool, entry oplock.LockEntry, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now().UTC() + entry, ok := s.ops[operation] + if !ok { + entry.Operation = operation + entry.Version = 1 + entry.LastAcquired = now + entry.AcquiredUntil = now.Add(duration) + s.ops[operation] = entry + return true, entry, nil + } + + if entry.AcquiredUntil.Before(now) { + entry.Version++ + entry.LastAcquired = now + entry.AcquiredUntil = now.Add(duration) + s.ops[operation] = entry + return true, entry, nil + } + + return false, entry, nil +} diff --git a/internal/storage/oplock/memory/memory_test.go b/internal/storage/oplock/memory/memory_test.go new file mode 100644 index 0000000000..1747dc72f9 --- /dev/null +++ b/internal/storage/oplock/memory/memory_test.go @@ -0,0 +1,11 @@ +package memory + +import ( + "testing" + + oplocktesting "go.flipt.io/flipt/internal/storage/oplock/testing" +) + +func Test_Harness(t *testing.T) { + oplocktesting.Harness(t, New()) +} diff --git a/internal/storage/oplock/oplock.go b/internal/storage/oplock/oplock.go new file mode 100644 index 0000000000..70817405e8 --- /dev/null +++ b/internal/storage/oplock/oplock.go @@ -0,0 +1,30 @@ +package oplock + +import ( + "context" + "time" +) + +// Operation is a string which identifies a particular unique operation name. +type Operation string + +type LockEntry struct { + Operation Operation + Version int64 + LastAcquired time.Time + AcquiredUntil time.Time +} + +// Service is an operation lock service which provides the ability to lock access +// to perform a named operation up until an ellapsed duration. 
+// Implementations of this type can be used to ensure an operation occurs once per +// the provided elapsed duration between a set of Flipt instances. +// If coordinating a distributed set of Flipt instances then a remote backend (e.g. SQL) +// will be required. In-memory implementations will only work for single instance deployments. +type Service interface { + // TryAcquire will attempt to obtain a lock for the supplied operation name for the specified duration. + // If it succeeds then the returned boolean (acquired) will be true, else false. + // The lock entry associated with the last successful acquisition is also returned. + // Given the lock was acquired successfully this will be the entry just created. + TryAcquire(ctx context.Context, operation Operation, duration time.Duration) (acquired bool, entry LockEntry, err error) +} diff --git a/internal/storage/oplock/sql/sql.go b/internal/storage/oplock/sql/sql.go new file mode 100644 index 0000000000..b03fde7d84 --- /dev/null +++ b/internal/storage/oplock/sql/sql.go @@ -0,0 +1,164 @@ +package memory + +import ( + "context" + "fmt" + "time" + + sq "github.com/Masterminds/squirrel" + "go.flipt.io/flipt/errors" + "go.flipt.io/flipt/internal/storage/oplock" + storagesql "go.flipt.io/flipt/internal/storage/sql" + "go.uber.org/zap" +) + +// Service is a SQL-backed implementation of the oplock.Service. +// It is safe for coordinating multiple Flipt instances sharing the same database. +type Service struct { + logger *zap.Logger + driver storagesql.Driver + builder sq.StatementBuilderType +} + +// New constructs and configures a new service instance. +func New(logger *zap.Logger, driver storagesql.Driver, builder sq.StatementBuilderType) *Service { + return &Service{ + logger: logger, + driver: driver, + builder: builder, + } +} + +// TryAcquire will attempt to obtain a lock for the supplied operation name for the specified duration. +// If it succeeds then the returned boolean (acquired) will be true, else false. 
+// The lock entry associated with the last successful acquisition is also returned. +// Given the lock was acquired successfully this will be the entry just created. +func (s *Service) TryAcquire(ctx context.Context, operation oplock.Operation, duration time.Duration) (acquired bool, entry oplock.LockEntry, err error) { + entry, err = s.readEntry(ctx, operation) + if err != nil { + if _, match := errors.As[errors.ErrNotFound](err); match { + // entry does not exist so we try and create one + entry, err := s.insertEntry(ctx, operation, duration) + if err != nil { + if _, match := errors.As[errors.ErrInvalid](err); match { + // check if the entry is invalid due to + // uniqueness constraint violation + // if so re-read the current entry and return that + entry, err := s.readEntry(ctx, operation) + return false, entry, err + } + + return false, entry, err + } + + return true, entry, nil + } + + // something went wrong + return false, entry, err + } + + // entry exists so first check the acquired until has elapsed + if time.Now().UTC().Before(entry.AcquiredUntil) { + // return early as the lock is still acquired + return false, entry, nil + } + + acquired, err = s.acquireEntry(ctx, &entry, duration) + + return acquired, entry, err +} + +func (s *Service) acquireEntry(ctx context.Context, entry *oplock.LockEntry, dur time.Duration) (acquired bool, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("updating existing entry: %w", s.driver.AdaptError(err)) + } + }() + + now := time.Now().UTC() + query := s.builder.Update("operation_lock"). + Set("version", entry.Version+1). + Set("last_acquired_at", now). + Set("acquired_until", now.Add(dur)). 
+ Where(sq.Eq{ + "operation": string(entry.Operation), + // ensure current entry has not been updated + "version": entry.Version, + }) + + res, err := query.ExecContext(ctx) + if err != nil { + return false, err + } + + count, err := res.RowsAffected() + if err != nil { + return false, err + } + + if count < 1 { + // current entry version does not match + // therefore we can assume it was updated + // by concurrent lock acquirer + return false, nil + } + + entry.Version++ + entry.LastAcquired = now + entry.AcquiredUntil = now.Add(dur) + return true, nil +} + +func (s *Service) insertEntry(ctx context.Context, op oplock.Operation, dur time.Duration) (entry oplock.LockEntry, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("inserting new entry: %w", err) + } + }() + + entry.Operation = op + entry.Version = 1 + entry.LastAcquired = time.Now().UTC() + entry.AcquiredUntil = entry.LastAcquired.Add(dur) + + _, err = s.builder.Insert("operation_lock"). + Columns( + "operation", + "version", + "last_acquired_at", + "acquired_until", + ).Values( + &entry.Operation, + &entry.Version, + &entry.LastAcquired, + &entry.AcquiredUntil, + ).ExecContext(ctx) + + return entry, s.driver.AdaptError(err) +} + +func (s *Service) readEntry(ctx context.Context, operation oplock.Operation) (entry oplock.LockEntry, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("reading entry: %w", err) + } + }() + + err = s.builder.Select( + "operation", + "version", + "last_acquired_at", + "acquired_until", + ).From("operation_lock"). + Where(sq.Eq{"operation": string(operation)}). + QueryRowContext(ctx). 
+		Scan(
+			&entry.Operation,
+			&entry.Version,
+			&entry.LastAcquired,
+			&entry.AcquiredUntil,
+		)
+
+	return entry, s.driver.AdaptError(err)
+}
diff --git a/internal/storage/oplock/sql/sql_test.go b/internal/storage/oplock/sql/sql_test.go
new file mode 100644
index 0000000000..3ed7defe02
--- /dev/null
+++ b/internal/storage/oplock/sql/sql_test.go
@@ -0,0 +1,26 @@
+package sql
+
+import (
+	"testing"
+
+	oplocktesting "go.flipt.io/flipt/internal/storage/oplock/testing"
+	storagesql "go.flipt.io/flipt/internal/storage/sql"
+	sqltesting "go.flipt.io/flipt/internal/storage/sql/testing"
+	"go.uber.org/zap/zaptest"
+)
+
+func Test_Harness(t *testing.T) {
+	logger := zaptest.NewLogger(t)
+	db, err := sqltesting.Open()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	oplocktesting.Harness(
+		t,
+		New(
+			logger,
+			db.Driver,
+			storagesql.BuilderFor(db.DB, db.Driver),
+		))
+}
diff --git a/internal/storage/oplock/testing/testing.go b/internal/storage/oplock/testing/testing.go
new file mode 100644
index 0000000000..1924976acd
--- /dev/null
+++ b/internal/storage/oplock/testing/testing.go
@@ -0,0 +1,103 @@
+package testing
+
+import (
+	"context"
+	"errors"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.flipt.io/flipt/internal/storage/oplock"
+	"golang.org/x/sync/errgroup"
+)
+
+// Harness is a test harness for all implementations of oplock.Service.
+// The test consists of firing multiple goroutines which attempt to acquire
+// a lock over a single operation "test".
+// Each acquisition's timestamp is pushed down a channel.
+// When five lock acquisitions have occurred the test ensures that it took
+// at least a specified duration to do so (interval * (iterations - 1)).
+// It also ensures that acquisitions occurred in ascending timestamp order
+// with a delta between each tick of at least the configured interval.
+func Harness(t *testing.T, s oplock.Service) {
+	var (
+		acquiredAt  = make(chan time.Time, 1)
+		interval    = 2 * time.Second
+		op          = oplock.Operation("test")
+		ctx, cancel = context.WithCancel(context.Background())
+	)
+
+	errgroup, ctx := errgroup.WithContext(ctx)
+
+	for i := 0; i < 5; i++ {
+		var acquiredUntil = time.Now().UTC()
+
+		errgroup.Go(func() error {
+			for {
+				select {
+				case <-ctx.Done():
+					return nil
+				case <-time.After(time.Until(acquiredUntil)):
+				}
+
+				acquired, entry, err := s.TryAcquire(ctx, op, interval)
+				if err != nil {
+					return err
+				}
+
+				if acquired {
+					acquiredAt <- entry.LastAcquired
+				}
+
+				acquiredUntil = entry.AcquiredUntil
+			}
+		})
+	}
+
+	now := time.Now().UTC()
+	var acquisitions []time.Time
+	for tick := range acquiredAt {
+		acquisitions = append(acquisitions, tick)
+
+		if len(acquisitions) == 5 {
+			break
+		}
+	}
+
+	since := time.Since(now)
+
+	// ensure it took at least 8 seconds to acquire 5 locks
+	require.Greater(t, since, 8*time.Second)
+
+	t.Logf("It took %s to consume the lock 5 times with an interval of %s\n", since, interval)
+
+	cancel()
+
+	if err := errgroup.Wait(); err != nil {
+		// there are a couple of acceptable errors here (context.Canceled and "operation was canceled")
+		// stdlib net package can adapt context.Canceled into an unexported errCanceled
+		// https://github.com/golang/go/blob/6b45863e47ad1a27ba3051ce0407f0bdc7b46113/src/net/net.go#L428-L439
+		switch {
+		case errors.Is(err, context.Canceled):
+		case strings.Contains(err.Error(), "operation was canceled"):
+		default:
+			require.FailNowf(t, "error not as expected", "%v", err)
+		}
+	}
+
+	close(acquiredAt)
+
+	// ensure ticks were acquired sequentially
+	assert.IsIncreasing(t, acquisitions)
+
+	for i, tick := range acquisitions {
+		if len(acquisitions) == i+1 {
+			break
+		}
+
+		// tick at T(n+1) occurs at-least after T(n)
+		assert.Greater(t, acquisitions[i+1].Sub(tick), interval)
+	}
+}
diff --git a/internal/storage/sql/migrator.go 
b/internal/storage/sql/migrator.go index 09f075892b..595f3714a6 100644 --- a/internal/storage/sql/migrator.go +++ b/internal/storage/sql/migrator.go @@ -17,10 +17,10 @@ import ( ) var expectedVersions = map[Driver]uint{ - SQLite: 4, - Postgres: 4, - MySQL: 2, - CockroachDB: 1, + SQLite: 5, + Postgres: 5, + MySQL: 3, + CockroachDB: 2, } // Migrator is responsible for migrating the database schema diff --git a/internal/storage/sql/testing/testing.go b/internal/storage/sql/testing/testing.go index ecfcb00f48..e304062ce2 100644 --- a/internal/storage/sql/testing/testing.go +++ b/internal/storage/sql/testing/testing.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "os" + "time" "github.com/docker/go-connections/nat" "github.com/golang-migrate/migrate/v4" @@ -179,6 +180,16 @@ func Open() (*Database, error) { } } + db.SetConnMaxLifetime(2 * time.Minute) + db.SetConnMaxIdleTime(time.Minute) + + // 2 minute timeout attempting to establish first connection + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + if err := db.PingContext(ctx); err != nil { + return nil, err + } + return &Database{ DB: db, Driver: driver,