diff --git a/courier/courier.go b/courier/courier.go index 9f549efa0706..3e800408e560 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net/url" "strconv" "time" @@ -16,27 +17,41 @@ import ( gomail "github.com/ory/mail/v3" - "github.com/ory/kratos/driver/config" "github.com/ory/kratos/x" ) type ( - smtpDependencies interface { + SMTPConfig interface { + CourierSMTPURL() *url.URL + CourierSMTPFrom() string + CourierSMTPFromName() string + CourierSMTPHeaders() map[string]string + CourierTemplatesRoot() string + } + SMTPDependencies interface { PersistenceProvider x.LoggingProvider - config.Provider + ConfigProvider } - Courier struct { - Dialer *gomail.Dialer - d smtpDependencies + TemplateTyper func(t EmailTemplate) (TemplateType, error) + EmailTemplateFromMessage func(c SMTPConfig, msg Message) (EmailTemplate, error) + Courier struct { + Dialer *gomail.Dialer + d SMTPDependencies + GetTemplateType TemplateTyper + NewEmailTemplateFromMessage EmailTemplateFromMessage } Provider interface { Courier(ctx context.Context) *Courier } + ConfigProvider interface { + CourierConfig(ctx context.Context) SMTPConfig + } ) -func NewSMTP(d smtpDependencies, c *config.Config) *Courier { - uri := c.CourierSMTPURL() +func NewSMTP(ctx context.Context, d SMTPDependencies) *Courier { + uri := d.CourierConfig(ctx).CourierSMTPURL() + password, _ := uri.User.Password() port, _ := strconv.ParseInt(uri.Port(), 10, 0) @@ -73,8 +88,10 @@ func NewSMTP(d smtpDependencies, c *config.Config) *Courier { } return &Courier{ - d: d, - Dialer: dialer, + d: d, + Dialer: dialer, + GetTemplateType: GetTemplateType, + NewEmailTemplateFromMessage: NewEmailTemplateFromMessage, } } @@ -94,7 +111,7 @@ func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e return uuid.Nil, err } - templateType, err := GetTemplateType(t) + templateType, err := m.GetTemplateType(t) if err != nil { return uuid.Nil, err } @@ -113,6 +130,7 @@ func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e TemplateType: templateType, TemplateData: templateData, } + if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil { return uuid.Nil, err } @@ -151,8 +169,8 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error { switch msg.Type { case MessageTypeEmail: - from := m.d.Config(ctx).CourierSMTPFrom() - fromName := m.d.Config(ctx).CourierSMTPFromName() + from := m.d.CourierConfig(ctx).CourierSMTPFrom() + fromName := m.d.CourierConfig(ctx).CourierSMTPFromName() gm := gomail.NewMessage() if fromName == "" { gm.SetHeader("From", from) @@ -163,14 +181,14 @@ func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error { gm.SetHeader("To", msg.Recipient) gm.SetHeader("Subject", msg.Subject) - headers := m.d.Config(ctx).CourierSMTPHeaders() + headers := m.d.CourierConfig(ctx).CourierSMTPHeaders() for k, v := range headers { gm.SetHeader(k, v) } gm.SetBody("text/plain", msg.Body) - tmpl, err := NewEmailTemplateFromMessage(m.d.Config(ctx), msg) + tmpl, err := m.NewEmailTemplateFromMessage(m.d.CourierConfig(ctx), msg) if err != nil { m.d.Logger(). WithError(err). diff --git a/courier/courier_test.go b/courier/courier_test.go index 86499acfcd0b..866a975f30cb 100644 --- a/courier/courier_test.go +++ b/courier/courier_test.go @@ -35,11 +35,13 @@ func TestMain(m *testing.M) { } func TestNewSMTP(t *testing.T) { + ctx := context.Background() + setupConfig := func(stringURL string) *courier.Courier { - conf, _ := internal.NewFastRegistryWithMocks(t) + conf, reg := internal.NewFastRegistryWithMocks(t) conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL) t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String()) - return courier.NewSMTP(nil, conf) + return courier.NewSMTP(ctx, reg) } if testing.Short() { diff --git a/courier/persistence.go b/courier/persistence.go index c3202599397f..54811daefe4b 100644 --- a/courier/persistence.go +++ b/courier/persistence.go @@ -19,7 +19,6 @@ type ( LatestQueuedMessage(ctx context.Context) (*Message, error) } - PersistenceProvider interface { CourierPersister() Persister } diff --git a/courier/template/recovery_invalid.go b/courier/template/recovery_invalid.go index d5a768689dc6..75ea09ed3122 100644 --- a/courier/template/recovery_invalid.go +++ b/courier/template/recovery_invalid.go @@ -2,13 +2,11 @@ package template import ( "encoding/json" - - "github.com/ory/kratos/driver/config" ) type ( RecoveryInvalid struct { - c *config.Config + c TemplateConfig m *RecoveryInvalidModel } RecoveryInvalidModel struct { @@ -16,7 +14,7 @@ type ( } ) -func NewRecoveryInvalid(c *config.Config, m *RecoveryInvalidModel) *RecoveryInvalid { +func NewRecoveryInvalid(c TemplateConfig, m *RecoveryInvalidModel) *RecoveryInvalid { return &RecoveryInvalid{c: c, m: m} } diff --git a/courier/template/recovery_valid.go b/courier/template/recovery_valid.go index 51e58ac4771c..5a84a62bf873 100644 --- a/courier/template/recovery_valid.go +++ b/courier/template/recovery_valid.go @@ -2,13 +2,11 @@ package template import ( "encoding/json" - - "github.com/ory/kratos/driver/config" ) type ( RecoveryValid struct { - c *config.Config + c TemplateConfig m *RecoveryValidModel } RecoveryValidModel struct { @@ -18,7 +16,7 @@ type ( } ) -func NewRecoveryValid(c *config.Config, m *RecoveryValidModel) *RecoveryValid { +func NewRecoveryValid(c TemplateConfig, m *RecoveryValidModel) *RecoveryValid { return &RecoveryValid{c: c, m: m} } diff --git a/courier/template/stub.go b/courier/template/stub.go index 455a60de36f4..87f28a6415d7 100644 --- a/courier/template/stub.go +++ b/courier/template/stub.go @@ -2,12 +2,10 @@ package template import ( "encoding/json" - - "github.com/ory/kratos/driver/config" ) type TestStub struct { - c *config.Config + c TemplateConfig m *TestStubModel } @@ -17,7 +15,7 @@ type TestStubModel struct { Body string } -func NewTestStub(c *config.Config, m *TestStubModel) *TestStub { +func NewTestStub(c TemplateConfig, m *TestStubModel) *TestStub { return &TestStub{c: c, m: m} } diff --git a/courier/template/template.go b/courier/template/template.go new file mode 100644 index 000000000000..0486356fb0e8 --- /dev/null +++ b/courier/template/template.go @@ -0,0 +1,7 @@ +package template + +type ( + TemplateConfig interface { + CourierTemplatesRoot() string + } +) diff --git a/courier/template/verification_invalid.go b/courier/template/verification_invalid.go index 9b6736e56d65..aafc23d242d8 100644 --- a/courier/template/verification_invalid.go +++ b/courier/template/verification_invalid.go @@ -2,13 +2,11 @@ package template import ( "encoding/json" - - "github.com/ory/kratos/driver/config" ) type ( VerificationInvalid struct { - c *config.Config + c TemplateConfig m *VerificationInvalidModel } VerificationInvalidModel struct { @@ -16,7 +14,7 @@ type ( } ) -func NewVerificationInvalid(c *config.Config, m *VerificationInvalidModel) *VerificationInvalid { +func NewVerificationInvalid(c TemplateConfig, m *VerificationInvalidModel) *VerificationInvalid { return &VerificationInvalid{c: c, m: m} } diff --git a/courier/template/verification_valid.go b/courier/template/verification_valid.go index c3814b41d809..5d673646b4b4 100644 --- a/courier/template/verification_valid.go +++ b/courier/template/verification_valid.go @@ -2,13 +2,11 @@ package template import ( "encoding/json" - - "github.com/ory/kratos/driver/config" ) type ( VerificationValid struct { - c *config.Config + c TemplateConfig m *VerificationValidModel } VerificationValidModel struct { @@ -18,7 +16,7 @@ type ( } ) -func NewVerificationValid(c *config.Config, m *VerificationValidModel) *VerificationValid { +func NewVerificationValid(c TemplateConfig, m *VerificationValidModel) *VerificationValid { return &VerificationValid{c: c, m: m} } diff --git a/courier/templates.go b/courier/templates.go index f12cee294459..e04da43e4cb3 100644 --- a/courier/templates.go +++ b/courier/templates.go @@ -6,10 +6,18 @@ import ( "github.com/pkg/errors" "github.com/ory/kratos/courier/template" - "github.com/ory/kratos/driver/config" ) -type TemplateType string +type ( + TemplateType string + EmailTemplate interface { + json.Marshaler + EmailSubject() (string, error) + EmailBody() (string, error) + EmailBodyPlaintext() (string, error) + EmailRecipient() (string, error) + } +) const ( TypeRecoveryInvalid TemplateType = "recovery_invalid" @@ -19,14 +27,6 @@ const ( TypeTestStub TemplateType = "stub" ) -type EmailTemplate interface { - json.Marshaler - EmailSubject() (string, error) - EmailBody() (string, error) - EmailBodyPlaintext() (string, error) - EmailRecipient() (string, error) -} - func GetTemplateType(t EmailTemplate) (TemplateType, error) { switch t.(type) { case *template.RecoveryInvalid: @@ -44,39 +44,39 @@ func GetTemplateType(t EmailTemplate) (TemplateType, error) { } } -func NewEmailTemplateFromMessage(c *config.Config, m Message) (EmailTemplate, error) { - switch m.TemplateType { +func NewEmailTemplateFromMessage(c SMTPConfig, msg Message) (EmailTemplate, error) { + switch msg.TemplateType { case TypeRecoveryInvalid: var t template.RecoveryInvalidModel - if err := json.Unmarshal(m.TemplateData, &t); err != nil { + if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } return template.NewRecoveryInvalid(c, &t), nil case TypeRecoveryValid: var t template.RecoveryValidModel - if err := json.Unmarshal(m.TemplateData, &t); err != nil { + if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } return template.NewRecoveryValid(c, &t), nil case TypeVerificationInvalid: var t template.VerificationInvalidModel - if err := json.Unmarshal(m.TemplateData, &t); err != nil { + if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } return template.NewVerificationInvalid(c, &t), nil case TypeVerificationValid: var t template.VerificationValidModel - if err := json.Unmarshal(m.TemplateData, &t); err != nil { + if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } return template.NewVerificationValid(c, &t), nil case TypeTestStub: var t template.TestStubModel - if err := json.Unmarshal(m.TemplateData, &t); err != nil { + if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } return template.NewTestStub(c, &t), nil default: - return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType) + return nil, errors.Errorf("received unexpected message template type: %s", msg.TemplateType) } } diff --git a/driver/registry_default.go b/driver/registry_default.go index c6fe37a9eeba..56b00a024964 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "github.com/gobuffalo/pop/v5" + "github.com/ory/nosurf" "github.com/ory/kratos/selfservice/strategy/webauthn" @@ -22,8 +24,6 @@ import ( prometheus "github.com/ory/x/prometheusx" - "github.com/gobuffalo/pop/v5" - "github.com/ory/kratos/cipher" "github.com/ory/kratos/continuity" "github.com/ory/kratos/hash" @@ -259,6 +259,14 @@ func (m *RegistryDefault) Config(ctx context.Context) *config.Config { return corp.ContextualizeConfig(ctx, m.c) } +func (m *RegistryDefault) CourierConfig(ctx context.Context) courier.SMTPConfig { + return m.Config(ctx) +} + +func (m *RegistryDefault) SMTPConfig(ctx context.Context) courier.SMTPConfig { + return m.Config(ctx) +} + func (m *RegistryDefault) selfServiceStrategies() []interface{} { if len(m.selfserviceStrategies) == 0 { m.selfserviceStrategies = []interface{}{ @@ -579,7 +587,7 @@ func (m *RegistryDefault) SetPersister(p persistence.Persister) { } func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier { - return courier.NewSMTP(m, m.Config(ctx)) + return courier.NewSMTP(ctx, m) } func (m *RegistryDefault) ContinuityManager() continuity.Manager { diff --git a/text/id.go b/text/id.go index c58a64f662d0..3347b4749fef 100644 --- a/text/id.go +++ b/text/id.go @@ -73,9 +73,9 @@ const ( ) const ( - InfoSelfServiceVerification ID = 1080000 + iota // 1070000 - InfoSelfServiceVerificationEmailSent // 1070001 - InfoSelfServiceVerificationSuccessful // 1070002 + InfoSelfServiceVerification ID = 1080000 + iota // 1080000 + InfoSelfServiceVerificationEmailSent // 1080001 + InfoSelfServiceVerificationSuccessful // 1080002 ) const (