Skip to content

Commit

Permalink
feat: generalise courier (ory#2019)
Browse files Browse the repository at this point in the history
  • Loading branch information
Benehiko authored Dec 6, 2021
1 parent d8cfaf2 commit f439114
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 62 deletions.
48 changes: 33 additions & 15 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
"strconv"
"time"

Expand All @@ -16,27 +17,41 @@ import (

gomail "github.com/ory/mail/v3"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
)

type (
smtpDependencies interface {
SMTPConfig interface {
CourierSMTPURL() *url.URL
CourierSMTPFrom() string
CourierSMTPFromName() string
CourierSMTPHeaders() map[string]string
CourierTemplatesRoot() string
}
SMTPDependencies interface {
PersistenceProvider
x.LoggingProvider
config.Provider
ConfigProvider
}
Courier struct {
Dialer *gomail.Dialer
d smtpDependencies
TemplateTyper func(t EmailTemplate) (TemplateType, error)
EmailTemplateFromMessage func(c SMTPConfig, msg Message) (EmailTemplate, error)
Courier struct {
Dialer *gomail.Dialer
d SMTPDependencies
GetTemplateType TemplateTyper
NewEmailTemplateFromMessage EmailTemplateFromMessage
}
Provider interface {
Courier(ctx context.Context) *Courier
}
ConfigProvider interface {
CourierConfig(ctx context.Context) SMTPConfig
}
)

func NewSMTP(d smtpDependencies, c *config.Config) *Courier {
uri := c.CourierSMTPURL()
func NewSMTP(ctx context.Context, d SMTPDependencies) *Courier {
uri := d.CourierConfig(ctx).CourierSMTPURL()

password, _ := uri.User.Password()
port, _ := strconv.ParseInt(uri.Port(), 10, 0)

Expand Down Expand Up @@ -73,8 +88,10 @@ func NewSMTP(d smtpDependencies, c *config.Config) *Courier {
}

return &Courier{
d: d,
Dialer: dialer,
d: d,
Dialer: dialer,
GetTemplateType: GetTemplateType,
NewEmailTemplateFromMessage: NewEmailTemplateFromMessage,
}
}

Expand All @@ -94,7 +111,7 @@ func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e
return uuid.Nil, err
}

templateType, err := GetTemplateType(t)
templateType, err := m.GetTemplateType(t)
if err != nil {
return uuid.Nil, err
}
Expand All @@ -113,6 +130,7 @@ func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e
TemplateType: templateType,
TemplateData: templateData,
}

if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil {
return uuid.Nil, err
}
Expand Down Expand Up @@ -151,8 +169,8 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) {
func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error {
switch msg.Type {
case MessageTypeEmail:
from := m.d.Config(ctx).CourierSMTPFrom()
fromName := m.d.Config(ctx).CourierSMTPFromName()
from := m.d.CourierConfig(ctx).CourierSMTPFrom()
fromName := m.d.CourierConfig(ctx).CourierSMTPFromName()
gm := gomail.NewMessage()
if fromName == "" {
gm.SetHeader("From", from)
Expand All @@ -163,14 +181,14 @@ func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error {
gm.SetHeader("To", msg.Recipient)
gm.SetHeader("Subject", msg.Subject)

headers := m.d.Config(ctx).CourierSMTPHeaders()
headers := m.d.CourierConfig(ctx).CourierSMTPHeaders()
for k, v := range headers {
gm.SetHeader(k, v)
}

gm.SetBody("text/plain", msg.Body)

tmpl, err := NewEmailTemplateFromMessage(m.d.Config(ctx), msg)
tmpl, err := m.NewEmailTemplateFromMessage(m.d.CourierConfig(ctx), msg)
if err != nil {
m.d.Logger().
WithError(err).
Expand Down
6 changes: 4 additions & 2 deletions courier/courier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ func TestMain(m *testing.M) {
}

func TestNewSMTP(t *testing.T) {
ctx := context.Background()

setupConfig := func(stringURL string) *courier.Courier {
conf, _ := internal.NewFastRegistryWithMocks(t)
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL)
t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String())
return courier.NewSMTP(nil, conf)
return courier.NewSMTP(ctx, reg)
}

if testing.Short() {
Expand Down
1 change: 0 additions & 1 deletion courier/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ type (

LatestQueuedMessage(ctx context.Context) (*Message, error)
}

PersistenceProvider interface {
CourierPersister() Persister
}
Expand Down
6 changes: 2 additions & 4 deletions courier/template/recovery_invalid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type (
RecoveryInvalid struct {
c *config.Config
c TemplateConfig
m *RecoveryInvalidModel
}
RecoveryInvalidModel struct {
To string
}
)

func NewRecoveryInvalid(c *config.Config, m *RecoveryInvalidModel) *RecoveryInvalid {
func NewRecoveryInvalid(c TemplateConfig, m *RecoveryInvalidModel) *RecoveryInvalid {
return &RecoveryInvalid{c: c, m: m}
}

Expand Down
6 changes: 2 additions & 4 deletions courier/template/recovery_valid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type (
RecoveryValid struct {
c *config.Config
c TemplateConfig
m *RecoveryValidModel
}
RecoveryValidModel struct {
Expand All @@ -18,7 +16,7 @@ type (
}
)

func NewRecoveryValid(c *config.Config, m *RecoveryValidModel) *RecoveryValid {
func NewRecoveryValid(c TemplateConfig, m *RecoveryValidModel) *RecoveryValid {
return &RecoveryValid{c: c, m: m}
}

Expand Down
6 changes: 2 additions & 4 deletions courier/template/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type TestStub struct {
c *config.Config
c TemplateConfig
m *TestStubModel
}

Expand All @@ -17,7 +15,7 @@ type TestStubModel struct {
Body string
}

func NewTestStub(c *config.Config, m *TestStubModel) *TestStub {
func NewTestStub(c TemplateConfig, m *TestStubModel) *TestStub {
return &TestStub{c: c, m: m}
}

Expand Down
7 changes: 7 additions & 0 deletions courier/template/template.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package template

type (
TemplateConfig interface {
CourierTemplatesRoot() string
}
)
6 changes: 2 additions & 4 deletions courier/template/verification_invalid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type (
VerificationInvalid struct {
c *config.Config
c TemplateConfig
m *VerificationInvalidModel
}
VerificationInvalidModel struct {
To string
}
)

func NewVerificationInvalid(c *config.Config, m *VerificationInvalidModel) *VerificationInvalid {
func NewVerificationInvalid(c TemplateConfig, m *VerificationInvalidModel) *VerificationInvalid {
return &VerificationInvalid{c: c, m: m}
}

Expand Down
6 changes: 2 additions & 4 deletions courier/template/verification_valid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type (
VerificationValid struct {
c *config.Config
c TemplateConfig
m *VerificationValidModel
}
VerificationValidModel struct {
Expand All @@ -18,7 +16,7 @@ type (
}
)

func NewVerificationValid(c *config.Config, m *VerificationValidModel) *VerificationValid {
func NewVerificationValid(c TemplateConfig, m *VerificationValidModel) *VerificationValid {
return &VerificationValid{c: c, m: m}
}

Expand Down
36 changes: 18 additions & 18 deletions courier/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@ import (
"github.com/pkg/errors"

"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/driver/config"
)

type TemplateType string
type (
TemplateType string
EmailTemplate interface {
json.Marshaler
EmailSubject() (string, error)
EmailBody() (string, error)
EmailBodyPlaintext() (string, error)
EmailRecipient() (string, error)
}
)

const (
TypeRecoveryInvalid TemplateType = "recovery_invalid"
Expand All @@ -19,14 +27,6 @@ const (
TypeTestStub TemplateType = "stub"
)

type EmailTemplate interface {
json.Marshaler
EmailSubject() (string, error)
EmailBody() (string, error)
EmailBodyPlaintext() (string, error)
EmailRecipient() (string, error)
}

func GetTemplateType(t EmailTemplate) (TemplateType, error) {
switch t.(type) {
case *template.RecoveryInvalid:
Expand All @@ -44,39 +44,39 @@ func GetTemplateType(t EmailTemplate) (TemplateType, error) {
}
}

func NewEmailTemplateFromMessage(c *config.Config, m Message) (EmailTemplate, error) {
switch m.TemplateType {
func NewEmailTemplateFromMessage(c SMTPConfig, msg Message) (EmailTemplate, error) {
switch msg.TemplateType {
case TypeRecoveryInvalid:
var t template.RecoveryInvalidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewRecoveryInvalid(c, &t), nil
case TypeRecoveryValid:
var t template.RecoveryValidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewRecoveryValid(c, &t), nil
case TypeVerificationInvalid:
var t template.VerificationInvalidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewVerificationInvalid(c, &t), nil
case TypeVerificationValid:
var t template.VerificationValidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewVerificationValid(c, &t), nil
case TypeTestStub:
var t template.TestStubModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewTestStub(c, &t), nil
default:
return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType)
return nil, errors.Errorf("received unexpected message template type: %s", msg.TemplateType)
}
}
14 changes: 11 additions & 3 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"sync"
"time"

"github.com/gobuffalo/pop/v5"

"github.com/ory/nosurf"

"github.com/ory/kratos/selfservice/strategy/webauthn"
Expand All @@ -22,8 +24,6 @@ import (

prometheus "github.com/ory/x/prometheusx"

"github.com/gobuffalo/pop/v5"

"github.com/ory/kratos/cipher"
"github.com/ory/kratos/continuity"
"github.com/ory/kratos/hash"
Expand Down Expand Up @@ -259,6 +259,14 @@ func (m *RegistryDefault) Config(ctx context.Context) *config.Config {
return corp.ContextualizeConfig(ctx, m.c)
}

func (m *RegistryDefault) CourierConfig(ctx context.Context) courier.SMTPConfig {
return m.Config(ctx)
}

func (m *RegistryDefault) SMTPConfig(ctx context.Context) courier.SMTPConfig {
return m.Config(ctx)
}

func (m *RegistryDefault) selfServiceStrategies() []interface{} {
if len(m.selfserviceStrategies) == 0 {
m.selfserviceStrategies = []interface{}{
Expand Down Expand Up @@ -579,7 +587,7 @@ func (m *RegistryDefault) SetPersister(p persistence.Persister) {
}

func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier {
return courier.NewSMTP(m, m.Config(ctx))
return courier.NewSMTP(ctx, m)
}

func (m *RegistryDefault) ContinuityManager() continuity.Manager {
Expand Down
Loading

0 comments on commit f439114

Please sign in to comment.