Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generalise courier #2019

Merged
merged 4 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions courier/courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
"strconv"
"time"

Expand All @@ -16,27 +17,41 @@ import (

gomail "github.com/ory/mail/v3"

"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/x"
)

type (
smtpDependencies interface {
SMTPConfig interface {
CourierSMTPURL() *url.URL
CourierSMTPFrom() string
CourierSMTPFromName() string
CourierSMTPHeaders() map[string]string
CourierTemplatesRoot() string
}
SMTPDependencies interface {
PersistenceProvider
x.LoggingProvider
config.Provider
ConfigProvider
}
Courier struct {
Dialer *gomail.Dialer
d smtpDependencies
TemplateTyper func(t EmailTemplate) (TemplateType, error)
EmailTemplateFromMessage func(c SMTPConfig, msg Message) (EmailTemplate, error)
Courier struct {
Dialer *gomail.Dialer
d SMTPDependencies
GetTemplateType TemplateTyper
NewEmailTemplateFromMessage EmailTemplateFromMessage
Benehiko marked this conversation as resolved.
Show resolved Hide resolved
}
Provider interface {
Courier(ctx context.Context) *Courier
}
ConfigProvider interface {
CourierConfig(ctx context.Context) SMTPConfig
}
)

func NewSMTP(d smtpDependencies, c *config.Config) *Courier {
uri := c.CourierSMTPURL()
func NewSMTP(ctx context.Context, d SMTPDependencies) *Courier {
uri := d.CourierConfig(ctx).CourierSMTPURL()

password, _ := uri.User.Password()
port, _ := strconv.ParseInt(uri.Port(), 10, 0)

Expand Down Expand Up @@ -73,8 +88,10 @@ func NewSMTP(d smtpDependencies, c *config.Config) *Courier {
}

return &Courier{
d: d,
Dialer: dialer,
d: d,
Dialer: dialer,
GetTemplateType: GetTemplateType,
NewEmailTemplateFromMessage: NewEmailTemplateFromMessage,
}
}

Expand All @@ -94,7 +111,7 @@ func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e
return uuid.Nil, err
}

templateType, err := GetTemplateType(t)
templateType, err := m.GetTemplateType(t)
if err != nil {
return uuid.Nil, err
}
Expand All @@ -113,6 +130,7 @@ func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e
TemplateType: templateType,
TemplateData: templateData,
}

if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil {
return uuid.Nil, err
}
Expand Down Expand Up @@ -151,8 +169,8 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) {
func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error {
switch msg.Type {
case MessageTypeEmail:
from := m.d.Config(ctx).CourierSMTPFrom()
fromName := m.d.Config(ctx).CourierSMTPFromName()
from := m.d.CourierConfig(ctx).CourierSMTPFrom()
fromName := m.d.CourierConfig(ctx).CourierSMTPFromName()
gm := gomail.NewMessage()
if fromName == "" {
gm.SetHeader("From", from)
Expand All @@ -163,14 +181,14 @@ func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error {
gm.SetHeader("To", msg.Recipient)
gm.SetHeader("Subject", msg.Subject)

headers := m.d.Config(ctx).CourierSMTPHeaders()
headers := m.d.CourierConfig(ctx).CourierSMTPHeaders()
for k, v := range headers {
gm.SetHeader(k, v)
}

gm.SetBody("text/plain", msg.Body)

tmpl, err := NewEmailTemplateFromMessage(m.d.Config(ctx), msg)
tmpl, err := m.NewEmailTemplateFromMessage(m.d.CourierConfig(ctx), msg)
if err != nil {
m.d.Logger().
WithError(err).
Expand Down
6 changes: 4 additions & 2 deletions courier/courier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ func TestMain(m *testing.M) {
}

func TestNewSMTP(t *testing.T) {
ctx := context.Background()

setupConfig := func(stringURL string) *courier.Courier {
conf, _ := internal.NewFastRegistryWithMocks(t)
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL)
t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String())
return courier.NewSMTP(nil, conf)
return courier.NewSMTP(ctx, reg)
}

if testing.Short() {
Expand Down
1 change: 0 additions & 1 deletion courier/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ type (

LatestQueuedMessage(ctx context.Context) (*Message, error)
}

PersistenceProvider interface {
CourierPersister() Persister
}
Expand Down
6 changes: 2 additions & 4 deletions courier/template/recovery_invalid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type (
RecoveryInvalid struct {
c *config.Config
c TemplateConfig
m *RecoveryInvalidModel
}
RecoveryInvalidModel struct {
To string
}
)

func NewRecoveryInvalid(c *config.Config, m *RecoveryInvalidModel) *RecoveryInvalid {
func NewRecoveryInvalid(c TemplateConfig, m *RecoveryInvalidModel) *RecoveryInvalid {
return &RecoveryInvalid{c: c, m: m}
}

Expand Down
6 changes: 2 additions & 4 deletions courier/template/recovery_valid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type (
RecoveryValid struct {
c *config.Config
c TemplateConfig
m *RecoveryValidModel
}
RecoveryValidModel struct {
Expand All @@ -18,7 +16,7 @@ type (
}
)

func NewRecoveryValid(c *config.Config, m *RecoveryValidModel) *RecoveryValid {
func NewRecoveryValid(c TemplateConfig, m *RecoveryValidModel) *RecoveryValid {
return &RecoveryValid{c: c, m: m}
}

Expand Down
6 changes: 2 additions & 4 deletions courier/template/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type TestStub struct {
c *config.Config
c TemplateConfig
m *TestStubModel
}

Expand All @@ -17,7 +15,7 @@ type TestStubModel struct {
Body string
}

func NewTestStub(c *config.Config, m *TestStubModel) *TestStub {
func NewTestStub(c TemplateConfig, m *TestStubModel) *TestStub {
return &TestStub{c: c, m: m}
}

Expand Down
7 changes: 7 additions & 0 deletions courier/template/template.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package template

type (
TemplateConfig interface {
CourierTemplatesRoot() string
}
)
6 changes: 2 additions & 4 deletions courier/template/verification_invalid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type (
VerificationInvalid struct {
c *config.Config
c TemplateConfig
m *VerificationInvalidModel
}
VerificationInvalidModel struct {
To string
}
)

func NewVerificationInvalid(c *config.Config, m *VerificationInvalidModel) *VerificationInvalid {
func NewVerificationInvalid(c TemplateConfig, m *VerificationInvalidModel) *VerificationInvalid {
return &VerificationInvalid{c: c, m: m}
}

Expand Down
6 changes: 2 additions & 4 deletions courier/template/verification_valid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package template

import (
"encoding/json"

"github.com/ory/kratos/driver/config"
)

type (
VerificationValid struct {
c *config.Config
c TemplateConfig
m *VerificationValidModel
}
VerificationValidModel struct {
Expand All @@ -18,7 +16,7 @@ type (
}
)

func NewVerificationValid(c *config.Config, m *VerificationValidModel) *VerificationValid {
func NewVerificationValid(c TemplateConfig, m *VerificationValidModel) *VerificationValid {
return &VerificationValid{c: c, m: m}
}

Expand Down
36 changes: 18 additions & 18 deletions courier/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@ import (
"github.com/pkg/errors"

"github.com/ory/kratos/courier/template"
"github.com/ory/kratos/driver/config"
)

type TemplateType string
type (
TemplateType string
EmailTemplate interface {
json.Marshaler
EmailSubject() (string, error)
EmailBody() (string, error)
EmailBodyPlaintext() (string, error)
EmailRecipient() (string, error)
}
)

const (
TypeRecoveryInvalid TemplateType = "recovery_invalid"
Expand All @@ -19,14 +27,6 @@ const (
TypeTestStub TemplateType = "stub"
)

type EmailTemplate interface {
json.Marshaler
EmailSubject() (string, error)
EmailBody() (string, error)
EmailBodyPlaintext() (string, error)
EmailRecipient() (string, error)
}

func GetTemplateType(t EmailTemplate) (TemplateType, error) {
switch t.(type) {
case *template.RecoveryInvalid:
Expand All @@ -44,39 +44,39 @@ func GetTemplateType(t EmailTemplate) (TemplateType, error) {
}
}

func NewEmailTemplateFromMessage(c *config.Config, m Message) (EmailTemplate, error) {
switch m.TemplateType {
func NewEmailTemplateFromMessage(c SMTPConfig, msg Message) (EmailTemplate, error) {
switch msg.TemplateType {
case TypeRecoveryInvalid:
var t template.RecoveryInvalidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewRecoveryInvalid(c, &t), nil
case TypeRecoveryValid:
var t template.RecoveryValidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewRecoveryValid(c, &t), nil
case TypeVerificationInvalid:
var t template.VerificationInvalidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewVerificationInvalid(c, &t), nil
case TypeVerificationValid:
var t template.VerificationValidModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewVerificationValid(c, &t), nil
case TypeTestStub:
var t template.TestStubModel
if err := json.Unmarshal(m.TemplateData, &t); err != nil {
if err := json.Unmarshal(msg.TemplateData, &t); err != nil {
return nil, err
}
return template.NewTestStub(c, &t), nil
default:
return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType)
return nil, errors.Errorf("received unexpected message template type: %s", msg.TemplateType)
}
}
14 changes: 11 additions & 3 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"sync"
"time"

"github.com/gobuffalo/pop/v5"

"github.com/ory/nosurf"

"github.com/ory/kratos/selfservice/strategy/webauthn"
Expand All @@ -22,8 +24,6 @@ import (

prometheus "github.com/ory/x/prometheusx"

"github.com/gobuffalo/pop/v5"

"github.com/ory/kratos/cipher"
"github.com/ory/kratos/continuity"
"github.com/ory/kratos/hash"
Expand Down Expand Up @@ -259,6 +259,14 @@ func (m *RegistryDefault) Config(ctx context.Context) *config.Config {
return corp.ContextualizeConfig(ctx, m.c)
}

func (m *RegistryDefault) CourierConfig(ctx context.Context) courier.SMTPConfig {
return m.Config(ctx)
}

func (m *RegistryDefault) SMTPConfig(ctx context.Context) courier.SMTPConfig {
return m.Config(ctx)
}

func (m *RegistryDefault) selfServiceStrategies() []interface{} {
if len(m.selfserviceStrategies) == 0 {
m.selfserviceStrategies = []interface{}{
Expand Down Expand Up @@ -579,7 +587,7 @@ func (m *RegistryDefault) SetPersister(p persistence.Persister) {
}

func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier {
return courier.NewSMTP(m, m.Config(ctx))
return courier.NewSMTP(ctx, m)
}

func (m *RegistryDefault) ContinuityManager() continuity.Manager {
Expand Down
Loading