diff --git a/driver/configuration/provider_viper.go b/driver/configuration/provider_viper.go index 46cf089bb757..3bcda8cf2bb5 100644 --- a/driver/configuration/provider_viper.go +++ b/driver/configuration/provider_viper.go @@ -145,7 +145,13 @@ func (p *ViperProvider) PublicListenOn() string { } func (p *ViperProvider) DSN() string { - if dsn := viperx.GetString(p.l, ViperKeyDSN, ""); len(dsn) > 0 { + dsn := viperx.GetString(p.l, ViperKeyDSN, "") + + if dsn == "memory" { + return "sqlite://mem.db?mode=memory&_fk=true&cache=shared" + } + + if len(dsn) > 0 { return dsn } diff --git a/driver/configuration/provider_viper_test.go b/driver/configuration/provider_viper_test.go index e66cd39c976d..ff75be22cf6b 100644 --- a/driver/configuration/provider_viper_test.go +++ b/driver/configuration/provider_viper_test.go @@ -181,3 +181,57 @@ func TestViperProvider(t *testing.T) { }) }) } + +type InterceptHook struct { + lastEntry *logrus.Entry +} + +func (l InterceptHook) Levels() []logrus.Level { + return []logrus.Level{logrus.FatalLevel} +} + +func (l InterceptHook) Fire(e *logrus.Entry) error { + l.lastEntry = e + return nil +} + +func TestViperProvider_DSN(t *testing.T) { + t.Run("case=dsn: memory", func(t *testing.T) { + viper.Reset() + viper.Set(configuration.ViperKeyDSN, "memory") + + l := logrus.New() + p := configuration.NewViperProvider(l, false) + + assert.Equal(t, "sqlite://mem.db?mode=memory&_fk=true&cache=shared", p.DSN()) + }) + + t.Run("case=dsn: not memory", func(t *testing.T) { + dsn := "sqlite://foo.db?_fk=true" + viper.Reset() + viper.Set(configuration.ViperKeyDSN, dsn) + + l := logrus.New() + p := configuration.NewViperProvider(l, false) + + assert.Equal(t, dsn, p.DSN()) + }) + + t.Run("case=dsn: not set", func(t *testing.T) { + dsn := "" + viper.Reset() + viper.Set(configuration.ViperKeyDSN, dsn) + + l := logrus.New() + p := configuration.NewViperProvider(l, false) + + var exitCode int + l.ExitFunc = func(i int) { + exitCode = i + } + h := InterceptHook{} + l.AddHook(h) + assert.Equal(t, dsn, p.DSN()) + assert.NotEqual(t, 0, exitCode) + }) +} diff --git a/driver/registry.go b/driver/registry.go index 410227d2abdc..7eaf0bafcc3a 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -1,8 +1,12 @@ package driver import ( - "github.com/go-errors/errors" + "context" + "net/url" + "strings" + "github.com/gorilla/sessions" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/ory/kratos/courier" @@ -108,9 +112,10 @@ type selfServiceStrategy interface { } func NewRegistry(c configuration.Provider) (Registry, error) { - driver, err := dbal.GetDriverFor(c.DSN()) + dsn := c.DSN() + driver, err := dbal.GetDriverFor(dsn) if err != nil { - return nil, err + return nil, errors.WithStack(err) } registry, ok := driver.(Registry) @@ -118,5 +123,19 @@ func NewRegistry(c configuration.Provider) (Registry, error) { return nil, errors.Errorf("driver of type %T does not implement interface Registry", driver) } + // if dsn is memory we have to run the migrations on every start + if urlParts := strings.SplitN(dsn, "?", 1); len(urlParts) == 2 && strings.HasPrefix(dsn, "sqlite://") { + queryVals, err := url.ParseQuery(urlParts[1]) + if err != nil { + return nil, errors.WithMessage(errors.WithStack(err), "unable to parse the DSN url") + } + if queryVals.Get("mode") == "memory" { + registry.Logger().Print("Kratos is running migrations on every startup as DSN is memory.\n") + registry.Logger().Print("This means your data is lost when Kratos terminates.\n") + if err := registry.Persister().MigrateUp(context.Background()); err != nil { + return nil, err + } + } + } return registry, nil }