diff --git a/courier/courier.go b/courier/courier.go index 3e800408e560..2321696251ae 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -2,146 +2,71 @@ package courier import ( "context" - "crypto/tls" "encoding/json" - "fmt" "net/url" - "strconv" "time" "github.com/cenkalti/backoff" "github.com/gofrs/uuid" "github.com/pkg/errors" - "github.com/ory/herodot" - - gomail "github.com/ory/mail/v3" - "github.com/ory/kratos/x" + gomail "github.com/ory/mail/v3" ) type ( - SMTPConfig interface { + Config interface { CourierSMTPURL() *url.URL CourierSMTPFrom() string CourierSMTPFromName() string CourierSMTPHeaders() map[string]string + CourierSMSEnabled() bool + CourierSMSRequestConfig() json.RawMessage + CourierSMSFrom() string CourierTemplatesRoot() string } - SMTPDependencies interface { - PersistenceProvider - x.LoggingProvider - ConfigProvider - } - TemplateTyper func(t EmailTemplate) (TemplateType, error) - EmailTemplateFromMessage func(c SMTPConfig, msg Message) (EmailTemplate, error) - Courier struct { - Dialer *gomail.Dialer - d SMTPDependencies - GetTemplateType TemplateTyper - NewEmailTemplateFromMessage EmailTemplateFromMessage - } - Provider interface { - Courier(ctx context.Context) *Courier - } - ConfigProvider interface { - CourierConfig(ctx context.Context) SMTPConfig - } -) - -func NewSMTP(ctx context.Context, d SMTPDependencies) *Courier { - uri := d.CourierConfig(ctx).CourierSMTPURL() - - password, _ := uri.User.Password() - port, _ := strconv.ParseInt(uri.Port(), 10, 0) - dialer := &gomail.Dialer{ - Host: uri.Hostname(), - Port: int(port), - Username: uri.User.Username(), - Password: password, - - Timeout: time.Second * 10, - RetryFailure: true, - } - - sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) - - // SMTP schemes - // smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default) - // smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks - // and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314) - switch uri.Scheme { - case "smtp": - // Enforcing StartTLS by default for security best practices (config review, etc.) - skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls")) - if !skipStartTLS { - // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. - dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} - // Enforcing StartTLS - dialer.StartTLSPolicy = gomail.MandatoryStartTLS - } - case "smtps": - // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. - dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} - dialer.SSL = true - } - - return &Courier{ - d: d, - Dialer: dialer, - GetTemplateType: GetTemplateType, - NewEmailTemplateFromMessage: NewEmailTemplateFromMessage, - } -} - -func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { - recipient, err := t.EmailRecipient() - if err != nil { - return uuid.Nil, err - } - - subject, err := t.EmailSubject() - if err != nil { - return uuid.Nil, err + ConfigProvider interface { + CourierConfig(ctx context.Context) Config } - bodyPlaintext, err := t.EmailBodyPlaintext() - if err != nil { - return uuid.Nil, err + Dependencies interface { + PersistenceProvider + x.LoggingProvider + ConfigProvider } - templateType, err := m.GetTemplateType(t) - if err != nil { - return uuid.Nil, err + Courier interface { + Work(ctx context.Context) error + QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) + QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error) + SmtpDialer() *gomail.Dialer } - templateData, err := json.Marshal(t) - if err != nil { - return uuid.Nil, err + courier struct { + smsClient *smsClient + smtpClient *smtpClient + deps Dependencies } - message := &Message{ - Status: MessageStatusQueued, - Type: MessageTypeEmail, - Recipient: recipient, - Body: bodyPlaintext, - Subject: subject, - TemplateType: templateType, - TemplateData: templateData, + Provider interface { + Courier(ctx context.Context) Courier } +) - if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil { - return uuid.Nil, err +func NewCourier(ctx context.Context, deps Dependencies) Courier { + return &courier{ + smsClient: newSMS(ctx, deps), + smtpClient: newSMTP(ctx, deps), + deps: deps, } - return message.ID, nil } -func (m *Courier) Work(ctx context.Context) error { +func (c *courier) Work(ctx context.Context) error { errChan := make(chan error) defer close(errChan) - go m.watchMessages(ctx, errChan) + go c.watchMessages(ctx, errChan) select { case <-ctx.Done(): @@ -154,10 +79,10 @@ func (m *Courier) Work(ctx context.Context) error { } } -func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { +func (c *courier) watchMessages(ctx context.Context, errChan chan error) { for { if err := backoff.Retry(func() error { - return m.DispatchQueue(ctx) + return c.DispatchQueue(ctx) }, backoff.NewExponentialBackOff()); err != nil { errChan <- err return @@ -165,105 +90,3 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { time.Sleep(time.Second) } } - -func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error { - switch msg.Type { - case MessageTypeEmail: - from := m.d.CourierConfig(ctx).CourierSMTPFrom() - fromName := m.d.CourierConfig(ctx).CourierSMTPFromName() - gm := gomail.NewMessage() - if fromName == "" { - gm.SetHeader("From", from) - } else { - gm.SetAddressHeader("From", from, fromName) - } - - gm.SetHeader("To", msg.Recipient) - gm.SetHeader("Subject", msg.Subject) - - headers := m.d.CourierConfig(ctx).CourierSMTPHeaders() - for k, v := range headers { - gm.SetHeader(k, v) - } - - gm.SetBody("text/plain", msg.Body) - - tmpl, err := m.NewEmailTemplateFromMessage(m.d.CourierConfig(ctx), msg) - if err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to get email template from message.`) - } else { - htmlBody, err := tmpl.EmailBody() - if err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to get email body from template.`) - } else { - gm.AddAlternative("text/html", htmlBody) - } - } - - if err := m.Dialer.DialAndSend(ctx, gm); err != nil { - m.d.Logger(). - WithError(err). - WithField("smtp_server", fmt.Sprintf("%s:%d", m.Dialer.Host, m.Dialer.Port)). - WithField("smtp_ssl_enabled", m.Dialer.SSL). - // WithField("email_to", msg.Recipient). - WithField("message_from", from). - Error("Unable to send email using SMTP connection.") - return errors.WithStack(err) - } - - if err := m.d.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to set the message status to "sent".`) - return err - } - - m.d.Logger(). - WithField("message_id", msg.ID). - WithField("message_type", msg.Type). - WithField("message_template_type", msg.TemplateType). - WithField("message_subject", msg.Subject). - Debug("Courier sent out message.") - return nil - } - return errors.Errorf("received unexpected message type: %d", msg.Type) -} - -func (m *Courier) DispatchQueue(ctx context.Context) error { - if len(m.Dialer.Host) == 0 { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!")) - } - - messages, err := m.d.CourierPersister().NextMessages(ctx, 10) - if err != nil { - if errors.Is(err, ErrQueueEmpty) { - return nil - } - return err - } - - for k := range messages { - var msg = messages[k] - if err := m.DispatchMessage(ctx, msg); err != nil { - for _, replace := range messages[k:] { - if err := m.d.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", replace.ID). - Error(`Unable to reset the failed message's status to "queued".`) - } - } - - return err - } - } - - return nil -} diff --git a/courier/courier_test.go b/courier/courier_test.go index 866a975f30cb..b871e19fbe54 100644 --- a/courier/courier_test.go +++ b/courier/courier_test.go @@ -1,30 +1,10 @@ package courier_test import ( - "context" - "fmt" - "io/ioutil" - "net/http" "testing" - "time" - - "github.com/sirupsen/logrus" "github.com/ory/kratos/x" - gomail "github.com/ory/mail/v3" - - "github.com/gofrs/uuid" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - dhelper "github.com/ory/x/sqlcon/dockertest" - - courier "github.com/ory/kratos/courier" - templates "github.com/ory/kratos/courier/template" - "github.com/ory/kratos/driver/config" - "github.com/ory/kratos/internal" ) // nolint:staticcheck @@ -33,132 +13,3 @@ func TestMain(m *testing.M) { atexit.Add(x.CleanUpTestSMTP) atexit.Exit(m.Run()) } - -func TestNewSMTP(t *testing.T) { - ctx := context.Background() - - setupConfig := func(stringURL string) *courier.Courier { - conf, reg := internal.NewFastRegistryWithMocks(t) - conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL) - t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String()) - return courier.NewSMTP(ctx, reg) - } - - if testing.Short() { - t.SkipNow() - } - - //Should enforce StartTLS => dialer.StartTLSPolicy = gomail.MandatoryStartTLS and dialer.SSL = false - smtp := setupConfig("smtp://foo:bar@my-server:1234/") - assert.Equal(t, smtp.Dialer.StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced") - assert.Equal(t, smtp.Dialer.SSL, false, "Implicit TLS should not be enabled") - - //Should enforce TLS => dialer.SSL = true - smtp = setupConfig("smtps://foo:bar@my-server:1234/") - assert.Equal(t, smtp.Dialer.SSL, true, "Implicit TLS should be enabled") - - //Should allow cleartext => dialer.StartTLSPolicy = gomail.OpportunisticStartTLS and dialer.SSL = false - smtp = setupConfig("smtp://foo:bar@my-server:1234/?disable_starttls=true") - assert.Equal(t, smtp.Dialer.StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced") - assert.Equal(t, smtp.Dialer.SSL, false, "Implicit TLS should not be enabled") -} - -func TestSMTP(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - - smtp, api, err := x.RunTestSMTP() - require.NoError(t, err) - t.Logf("SMTP URL: %s", smtp) - t.Logf("API URL: %s", api) - - ctx := context.Background() - - conf, reg := internal.NewFastRegistryWithMocks(t) - conf.MustSet(config.ViperKeyCourierSMTPURL, smtp) - conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh") - reg.Logger().Level = logrus.TraceLevel - - c := reg.Courier(ctx) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - id, err := c.QueueEmail(ctx, templates.NewTestStub(conf, &templates.TestStubModel{ - To: "test-recipient-1@example.org", - Subject: "test-subject-1", - Body: "test-body-1", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - id, err = c.QueueEmail(ctx, templates.NewTestStub(conf, &templates.TestStubModel{ - To: "test-recipient-2@example.org", - Subject: "test-subject-2", - Body: "test-body-2", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - // The third email contains a sender name and custom headers - conf.MustSet(config.ViperKeyCourierSMTPFromName, "Bob") - conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header1", "foo") - conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header2", "bar") - customerHeaders := conf.CourierSMTPHeaders() - require.Len(t, customerHeaders, 2) - id, err = c.QueueEmail(ctx, templates.NewTestStub(conf, &templates.TestStubModel{ - To: "test-recipient-3@example.org", - Subject: "test-subject-3", - Body: "test-body-3", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - go func() { - require.NoError(t, c.Work(ctx)) - }() - - var body []byte - for k := 0; k < 30; k++ { - time.Sleep(time.Second) - err = func() error { - res, err := http.Get(api + "/api/v2/messages") - if err != nil { - return err - } - - defer res.Body.Close() - body, err = ioutil.ReadAll(res.Body) - if err != nil { - return err - } - - if http.StatusOK != res.StatusCode { - return errors.Errorf("expected status code 200 but got %d with body: %s", res.StatusCode, body) - } - - if total := gjson.GetBytes(body, "total").Int(); total != 3 { - return errors.Errorf("expected to have delivered at least 3 messages but got count %d with body: %s", total, body) - } - - return nil - }() - if err == nil { - break - } - } - require.NoError(t, err) - - for k := 1; k <= 3; k++ { - assert.Contains(t, string(body), fmt.Sprintf("test-subject-%d", k)) - assert.Contains(t, string(body), fmt.Sprintf("test-body-%d", k)) - assert.Contains(t, string(body), fmt.Sprintf("test-recipient-%d@example.org", k)) - assert.Contains(t, string(body), "test-stub@ory.sh") - } - - // Assertion for the third email with sender name and headers - assert.Contains(t, string(body), "Bob") - assert.Contains(t, string(body), `"test-stub-header1":["foo"]`) - assert.Contains(t, string(body), `"test-stub-header2":["bar"]`) -} diff --git a/courier/dispatcher.go b/courier/dispatcher.go new file mode 100644 index 000000000000..4d8beb7f2fab --- /dev/null +++ b/courier/dispatcher.go @@ -0,0 +1,67 @@ +package courier + +import ( + "context" + + "github.com/pkg/errors" +) + +func (c *courier) DispatchMessage(ctx context.Context, msg Message) error { + switch msg.Type { + case MessageTypeEmail: + if err := c.dispatchEmail(ctx, msg); err != nil { + return err + } + case MessageTypePhone: + if err := c.dispatchSMS(ctx, msg); err != nil { + return err + } + default: + return errors.Errorf("received unexpected message type: %d", msg.Type) + } + + if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to set the message status to "sent".`) + return err + } + + c.deps.Logger(). + WithField("message_id", msg.ID). + WithField("message_type", msg.Type). + WithField("message_template_type", msg.TemplateType). + WithField("message_subject", msg.Subject). + Debug("Courier sent out message.") + + return nil +} + +func (c *courier) DispatchQueue(ctx context.Context) error { + messages, err := c.deps.CourierPersister().NextMessages(ctx, 10) + if err != nil { + if errors.Is(err, ErrQueueEmpty) { + return nil + } + return err + } + + for k := range messages { + var msg = messages[k] + if err := c.DispatchMessage(ctx, msg); err != nil { + for _, replace := range messages[k:] { + if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", replace.ID). + Error(`Unable to reset the failed message's status to "queued".`) + } + } + + return err + } + } + + return nil +} diff --git a/courier/templates.go b/courier/email_templates.go similarity index 57% rename from courier/templates.go rename to courier/email_templates.go index e04da43e4cb3..f0e4c5e94c7d 100644 --- a/courier/templates.go +++ b/courier/email_templates.go @@ -5,77 +5,77 @@ import ( "github.com/pkg/errors" - "github.com/ory/kratos/courier/template" + "github.com/ory/kratos/courier/template/email" ) -type ( - TemplateType string - EmailTemplate interface { - json.Marshaler - EmailSubject() (string, error) - EmailBody() (string, error) - EmailBodyPlaintext() (string, error) - EmailRecipient() (string, error) - } -) +type TemplateType string const ( TypeRecoveryInvalid TemplateType = "recovery_invalid" TypeRecoveryValid TemplateType = "recovery_valid" TypeVerificationInvalid TemplateType = "verification_invalid" TypeVerificationValid TemplateType = "verification_valid" + TypeOTP TemplateType = "otp" TypeTestStub TemplateType = "stub" ) -func GetTemplateType(t EmailTemplate) (TemplateType, error) { +type EmailTemplate interface { + json.Marshaler + EmailSubject() (string, error) + EmailBody() (string, error) + EmailBodyPlaintext() (string, error) + EmailRecipient() (string, error) +} + +func GetEmailTemplateType(t EmailTemplate) (TemplateType, error) { switch t.(type) { - case *template.RecoveryInvalid: + case *email.RecoveryInvalid: return TypeRecoveryInvalid, nil - case *template.RecoveryValid: + case *email.RecoveryValid: return TypeRecoveryValid, nil - case *template.VerificationInvalid: + case *email.VerificationInvalid: return TypeVerificationInvalid, nil - case *template.VerificationValid: + case *email.VerificationValid: return TypeVerificationValid, nil - case *template.TestStub: + case *email.TestStub: return TypeTestStub, nil default: return "", errors.Errorf("unexpected template type") } } -func NewEmailTemplateFromMessage(c SMTPConfig, msg Message) (EmailTemplate, error) { +func NewEmailTemplateFromMessage(c Config, msg Message) (EmailTemplate, error) { switch msg.TemplateType { case TypeRecoveryInvalid: - var t template.RecoveryInvalidModel + var t email.RecoveryInvalidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewRecoveryInvalid(c, &t), nil + return email.NewRecoveryInvalid(c, &t), nil case TypeRecoveryValid: - var t template.RecoveryValidModel + var t email.RecoveryValidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewRecoveryValid(c, &t), nil + return email.NewRecoveryValid(c, &t), nil case TypeVerificationInvalid: - var t template.VerificationInvalidModel + var t email.VerificationInvalidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewVerificationInvalid(c, &t), nil + return email.NewVerificationInvalid(c, &t), nil case TypeVerificationValid: - var t template.VerificationValidModel + var t email.VerificationValidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewVerificationValid(c, &t), nil + return email.NewVerificationValid(c, &t), nil case TypeTestStub: - var t template.TestStubModel + var t email.TestStubModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewTestStub(c, &t), nil + return email.NewTestEmailStub(c, &t), nil default: return nil, errors.Errorf("received unexpected message template type: %s", msg.TemplateType) } diff --git a/courier/templates_test.go b/courier/email_templates_test.go similarity index 63% rename from courier/templates_test.go rename to courier/email_templates_test.go index ac6339482e72..a407e247eb6b 100644 --- a/courier/templates_test.go +++ b/courier/email_templates_test.go @@ -5,23 +5,24 @@ import ( "fmt" "testing" + "github.com/ory/kratos/courier/template/email" + "github.com/stretchr/testify/require" "github.com/ory/kratos/courier" - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) func TestGetTemplateType(t *testing.T) { for expectedType, tmpl := range map[courier.TemplateType]courier.EmailTemplate{ - courier.TypeRecoveryInvalid: &template.RecoveryInvalid{}, - courier.TypeRecoveryValid: &template.RecoveryValid{}, - courier.TypeVerificationInvalid: &template.VerificationInvalid{}, - courier.TypeVerificationValid: &template.VerificationValid{}, - courier.TypeTestStub: &template.TestStub{}, + courier.TypeRecoveryInvalid: &email.RecoveryInvalid{}, + courier.TypeRecoveryValid: &email.RecoveryValid{}, + courier.TypeVerificationInvalid: &email.VerificationInvalid{}, + courier.TypeVerificationValid: &email.VerificationValid{}, + courier.TypeTestStub: &email.TestStub{}, } { t.Run(fmt.Sprintf("case=%s", expectedType), func(t *testing.T) { - actualType, err := courier.GetTemplateType(tmpl) + actualType, err := courier.GetEmailTemplateType(tmpl) require.NoError(t, err) require.Equal(t, expectedType, actualType) @@ -33,11 +34,11 @@ func TestGetTemplateType(t *testing.T) { func TestNewEmailTemplateFromMessage(t *testing.T) { conf := internal.NewConfigurationWithDefaults(t) for tmplType, expectedTmpl := range map[courier.TemplateType]courier.EmailTemplate{ - courier.TypeRecoveryInvalid: template.NewRecoveryInvalid(conf, &template.RecoveryInvalidModel{To: "foo"}), - courier.TypeRecoveryValid: template.NewRecoveryValid(conf, &template.RecoveryValidModel{To: "bar", RecoveryURL: "http://foo.bar"}), - courier.TypeVerificationInvalid: template.NewVerificationInvalid(conf, &template.VerificationInvalidModel{To: "baz"}), - courier.TypeVerificationValid: template.NewVerificationValid(conf, &template.VerificationValidModel{To: "faz", VerificationURL: "http://bar.foo"}), - courier.TypeTestStub: template.NewTestStub(conf, &template.TestStubModel{To: "far", Subject: "test subject", Body: "test body"}), + courier.TypeRecoveryInvalid: email.NewRecoveryInvalid(conf, &email.RecoveryInvalidModel{To: "foo"}), + courier.TypeRecoveryValid: email.NewRecoveryValid(conf, &email.RecoveryValidModel{To: "bar", RecoveryURL: "http://foo.bar"}), + courier.TypeVerificationInvalid: email.NewVerificationInvalid(conf, &email.VerificationInvalidModel{To: "baz"}), + courier.TypeVerificationValid: email.NewVerificationValid(conf, &email.VerificationValidModel{To: "faz", VerificationURL: "http://bar.foo"}), + courier.TypeTestStub: email.NewTestEmailStub(conf, &email.TestStubModel{To: "far", Subject: "test subject", Body: "test body"}), } { t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) { tmplData, err := json.Marshal(expectedTmpl) diff --git a/courier/message.go b/courier/message.go index 94b102781cb2..0641a0d49b85 100644 --- a/courier/message.go +++ b/courier/message.go @@ -4,9 +4,9 @@ import ( "context" "time" - "github.com/ory/kratos/corp" - "github.com/gofrs/uuid" + + "github.com/ory/kratos/corp" ) type MessageStatus int @@ -21,6 +21,7 @@ type MessageType int const ( MessageTypeEmail MessageType = iota + 1 + MessageTypePhone ) // swagger:ignore diff --git a/courier/sms.go b/courier/sms.go new file mode 100644 index 000000000000..5d0b28321a95 --- /dev/null +++ b/courier/sms.go @@ -0,0 +1,121 @@ +package courier + +import ( + "context" + "encoding/json" + "errors" + "net/http" + + "github.com/gofrs/uuid" + + "github.com/ory/kratos/request" +) + +var ErrSMSGateAddressUnset = errors.New("failed to dispatch message - sms gate address is not set") + +type sendSMSRequestBody struct { + To string + From string + Body string +} + +type smsClient struct { + *http.Client + Host string + RequestConfig json.RawMessage + + GetTemplateType func(t SMSTemplate) (TemplateType, error) + NewTemplateFromMessage func(c Config, msg Message) (SMSTemplate, error) +} + +func newSMS(ctx context.Context, deps Dependencies) *smsClient { + if !deps.CourierConfig(ctx).CourierSMSEnabled() { + deps.Logger().Error("messages will not be sent - no sms gate server address is set in config") + } + + return &smsClient{ + Client: &http.Client{}, + RequestConfig: deps.CourierConfig(ctx).CourierSMSRequestConfig(), + + GetTemplateType: SMSTemplateType, + NewTemplateFromMessage: NewSMSTemplateFromMessage, + } + +} + +func (c *courier) QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error) { + recipient, err := t.PhoneNumber() + if err != nil { + return uuid.Nil, err + } + + templateType, err := c.smsClient.GetTemplateType(t) + if err != nil { + return uuid.Nil, err + } + + templateData, err := json.Marshal(t) + if err != nil { + return uuid.Nil, err + } + + message := &Message{ + Status: MessageStatusQueued, + Type: MessageTypePhone, + Recipient: recipient, + TemplateType: templateType, + TemplateData: templateData, + } + if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil { + return uuid.Nil, err + } + + return message.ID, nil +} + +func (c *courier) dispatchSMS(ctx context.Context, msg Message) error { + tmpl, err := c.smsClient.NewTemplateFromMessage(c.deps.CourierConfig(ctx), msg) + if err != nil { + return err + } + + body, err := tmpl.SMSBody() + if err != nil { + return err + } + + builder, err := request.NewBuilder(c.smsClient.RequestConfig, c.deps.Logger()) + if err != nil { + return err + } + + req, err := builder.BuildRequest(&sendSMSRequestBody{ + To: msg.Recipient, + From: c.deps.CourierConfig(ctx).CourierSMSFrom(), + Body: body, + }) + if err != nil { + return err + } + + res, err := c.smsClient.Do(req) + if err != nil { + return err + } + + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return errors.New(http.StatusText(res.StatusCode)) + } + + if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to set the message status to "sent".`) + return err + } + + return nil +} diff --git a/courier/sms_templates.go b/courier/sms_templates.go new file mode 100644 index 000000000000..955630ff918b --- /dev/null +++ b/courier/sms_templates.go @@ -0,0 +1,45 @@ +package courier + +import ( + "encoding/json" + + "github.com/pkg/errors" + + "github.com/ory/kratos/courier/template/sms" +) + +type SMSTemplate interface { + json.Marshaler + SMSBody() (string, error) + PhoneNumber() (string, error) +} + +func SMSTemplateType(t SMSTemplate) (TemplateType, error) { + switch t.(type) { + case *sms.OTPMessage: + return TypeOTP, nil + case *sms.TestStub: + return TypeTestStub, nil + default: + return "", errors.Errorf("unexpected template type") + } +} + +func NewSMSTemplateFromMessage(c Config, m Message) (SMSTemplate, error) { + switch m.TemplateType { + case TypeOTP: + var t sms.OTPMessageModel + if err := json.Unmarshal(m.TemplateData, &t); err != nil { + return nil, err + } + return sms.NewOTPMessage(c, &t), nil + case TypeTestStub: + var t sms.TestStubModel + if err := json.Unmarshal(m.TemplateData, &t); err != nil { + return nil, err + } + return sms.NewTestStub(c, &t), nil + default: + return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType) + } +} diff --git a/courier/sms_templates_test.go b/courier/sms_templates_test.go new file mode 100644 index 000000000000..503e2f747a90 --- /dev/null +++ b/courier/sms_templates_test.go @@ -0,0 +1,57 @@ +package courier_test + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestSMSTemplateType(t *testing.T) { + for expectedType, tmpl := range map[courier.TemplateType]courier.SMSTemplate{ + courier.TypeOTP: &sms.OTPMessage{}, + courier.TypeTestStub: &sms.TestStub{}, + } { + t.Run(fmt.Sprintf("case=%s", expectedType), func(t *testing.T) { + actualType, err := courier.SMSTemplateType(tmpl) + require.NoError(t, err) + require.Equal(t, expectedType, actualType) + }) + } +} + +func TestNewSMSTemplateFromMessage(t *testing.T) { + conf := internal.NewConfigurationWithDefaults(t) + for tmplType, expectedTmpl := range map[courier.TemplateType]courier.SMSTemplate{ + courier.TypeOTP: sms.NewOTPMessage(conf, &sms.OTPMessageModel{To: "+12345678901"}), + courier.TypeTestStub: sms.NewTestStub(conf, &sms.TestStubModel{To: "+12345678901", Body: "test body"}), + } { + t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) { + tmplData, err := json.Marshal(expectedTmpl) + require.NoError(t, err) + + m := courier.Message{TemplateType: tmplType, TemplateData: tmplData} + actualTmpl, err := courier.NewSMSTemplateFromMessage(conf, m) + require.NoError(t, err) + + require.IsType(t, expectedTmpl, actualTmpl) + + expectedRecipient, err := expectedTmpl.PhoneNumber() + require.NoError(t, err) + actualRecipient, err := actualTmpl.PhoneNumber() + require.NoError(t, err) + require.Equal(t, expectedRecipient, actualRecipient) + + expectedBody, err := expectedTmpl.SMSBody() + require.NoError(t, err) + actualBody, err := actualTmpl.SMSBody() + require.NoError(t, err) + require.Equal(t, expectedBody, actualBody) + }) + } +} diff --git a/courier/sms_test.go b/courier/sms_test.go new file mode 100644 index 000000000000..b705bf0490fa --- /dev/null +++ b/courier/sms_test.go @@ -0,0 +1,108 @@ +package courier_test + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/internal" + "github.com/ory/kratos/x" +) + +func TestQueueSMS(t *testing.T) { + expectedSender := "Kratos Test" + expectedSMS := []*sms.TestStubModel{ + { + To: "+12065550101", + Body: "test-sms-body-1", + }, + { + To: "+12065550102", + Body: "test-sms-body-2", + }, + } + + actual := make([]*sms.TestStubModel, 0, 2) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + type sendSMSRequestBody struct { + To string + From string + Body string + } + + rb, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var body sendSMSRequestBody + + err = json.Unmarshal(rb, &body) + require.NoError(t, err) + + assert.NotEmpty(t, r.Header["Authorization"]) + assert.Equal(t, "Basic bWU6MTIzNDU=", r.Header["Authorization"][0]) + + assert.Equal(t, body.From, expectedSender) + actual = append(actual, &sms.TestStubModel{ + To: body.To, + Body: body.Body, + }) + })) + + requestConfig := fmt.Sprintf(`{ + "url": "%s", + "method": "POST", + "body": "file://./stub/request.config.twilio.jsonnet", + "auth": { + "type": "basic_auth", + "config": { + "user": "me", + "password": "12345" + } + } + }`, srv.URL) + + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(config.ViperKeyCourierSMSRequestConfig, requestConfig) + conf.MustSet(config.ViperKeyCourierSMSFrom, expectedSender) + conf.MustSet(config.ViperKeyCourierSMSEnabled, true) + conf.MustSet(config.ViperKeyCourierSMTPURL, "http://foo.url") + reg.Logger().Level = logrus.TraceLevel + + ctx := context.Background() + + c := reg.Courier(ctx) + + ctx, cancel := context.WithCancel(ctx) + defer t.Cleanup(cancel) + + for _, message := range expectedSMS { + id, err := c.QueueSMS(ctx, sms.NewTestStub(conf, message)) + require.NoError(t, err) + x.RequireNotNilUUID(t, id) + } + + go func() { + require.NoError(t, c.Work(ctx)) + }() + + time.Sleep(time.Second) + for i, message := range actual { + expected := expectedSMS[i] + + assert.Equal(t, expected.To, message.To) + assert.Equal(t, fmt.Sprintf("stub sms body %s\n", expected.Body), message.Body) + } + + srv.Close() +} diff --git a/courier/smtp.go b/courier/smtp.go new file mode 100644 index 000000000000..367010d1b23f --- /dev/null +++ b/courier/smtp.go @@ -0,0 +1,187 @@ +package courier + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/ory/herodot" + gomail "github.com/ory/mail/v3" +) + +type smtpClient struct { + *gomail.Dialer + + GetTemplateType func(t EmailTemplate) (TemplateType, error) + NewTemplateFromMessage func(c Config, msg Message) (EmailTemplate, error) +} + +func newSMTP(ctx context.Context, deps Dependencies) *smtpClient { + uri := deps.CourierConfig(ctx).CourierSMTPURL() + + password, _ := uri.User.Password() + port, _ := strconv.ParseInt(uri.Port(), 10, 0) + + dialer := &gomail.Dialer{ + Host: uri.Hostname(), + Port: int(port), + Username: uri.User.Username(), + Password: password, + + Timeout: time.Second * 10, + RetryFailure: true, + } + + sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) + + // SMTP schemes + // smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default) + // smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks + // and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314) + switch uri.Scheme { + case "smtp": + // Enforcing StartTLS by default for security best practices (config review, etc.) + skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls")) + if !skipStartTLS { + // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. + dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} + // Enforcing StartTLS + dialer.StartTLSPolicy = gomail.MandatoryStartTLS + } + case "smtps": + // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. + dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} + dialer.SSL = true + } + + return &smtpClient{ + Dialer: dialer, + + GetTemplateType: GetEmailTemplateType, + NewTemplateFromMessage: NewEmailTemplateFromMessage, + } +} + +func (c *courier) SmtpDialer() *gomail.Dialer { + return c.smtpClient.Dialer +} + +func (c *courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { + recipient, err := t.EmailRecipient() + if err != nil { + return uuid.Nil, err + } + + subject, err := t.EmailSubject() + if err != nil { + return uuid.Nil, err + } + + bodyPlaintext, err := t.EmailBodyPlaintext() + if err != nil { + return uuid.Nil, err + } + + templateType, err := c.smtpClient.GetTemplateType(t) + if err != nil { + return uuid.Nil, err + } + + templateData, err := json.Marshal(t) + if err != nil { + return uuid.Nil, err + } + + message := &Message{ + Status: MessageStatusQueued, + Type: MessageTypeEmail, + Recipient: recipient, + Body: bodyPlaintext, + Subject: subject, + TemplateType: templateType, + TemplateData: templateData, + } + + if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil { + return uuid.Nil, err + } + + return message.ID, nil +} + +func (c *courier) dispatchEmail(ctx context.Context, msg Message) error { + if c.smtpClient.Host == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!")) + } + + from := c.deps.CourierConfig(ctx).CourierSMTPFrom() + fromName := c.deps.CourierConfig(ctx).CourierSMTPFromName() + + gm := gomail.NewMessage() + if fromName == "" { + gm.SetHeader("From", from) + } else { + gm.SetAddressHeader("From", from, fromName) + } + + gm.SetHeader("To", msg.Recipient) + gm.SetHeader("Subject", msg.Subject) + + headers := c.deps.CourierConfig(ctx).CourierSMTPHeaders() + for k, v := range headers { + gm.SetHeader(k, v) + } + + gm.SetBody("text/plain", msg.Body) + + tmpl, err := c.smtpClient.NewTemplateFromMessage(c.deps.CourierConfig(ctx), msg) + if err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to get email template from message.`) + } else { + htmlBody, err := tmpl.EmailBody() + if err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to get email body from template.`) + } else { + gm.AddAlternative("text/html", htmlBody) + } + } + + if err := c.smtpClient.DialAndSend(ctx, gm); err != nil { + c.deps.Logger(). + WithError(err). + WithField("smtp_server", fmt.Sprintf("%s:%d", c.smtpClient.Host, c.smtpClient.Port)). + WithField("smtp_ssl_enabled", c.smtpClient.SSL). + // WithField("email_to", msg.Recipient). + WithField("message_from", from). + Error("Unable to send email using SMTP connection.") + return errors.WithStack(err) + } + + if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to set the message status to "sent".`) + return err + } + + c.deps.Logger(). + WithField("message_id", msg.ID). + WithField("message_type", msg.Type). + WithField("message_template_type", msg.TemplateType). + WithField("message_subject", msg.Subject). + Debug("Courier sent out message.") + return nil +} diff --git a/courier/smtp_test.go b/courier/smtp_test.go new file mode 100644 index 000000000000..55c13423a7cb --- /dev/null +++ b/courier/smtp_test.go @@ -0,0 +1,155 @@ +package courier_test + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" + + templates "github.com/ory/kratos/courier/template/email" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + + "github.com/ory/kratos/courier" + "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/internal" + "github.com/ory/kratos/x" + gomail "github.com/ory/mail/v3" +) + +func TestNewSMTP(t *testing.T) { + ctx := context.Background() + + setupConfig := func(stringURL string) courier.Courier { + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL) + + t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String()) + + return courier.NewCourier(ctx, reg) + } + + if testing.Short() { + t.SkipNow() + } + + //Should enforce StartTLS => dialer.StartTLSPolicy = gomail.MandatoryStartTLS and dialer.SSL = false + smtp := setupConfig("smtp://foo:bar@my-server:1234/") + assert.Equal(t, smtp.SmtpDialer().StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced") + assert.Equal(t, smtp.SmtpDialer().SSL, false, "Implicit TLS should not be enabled") + + //Should enforce TLS => dialer.SSL = true + smtp = setupConfig("smtps://foo:bar@my-server:1234/") + assert.Equal(t, smtp.SmtpDialer().SSL, true, "Implicit TLS should be enabled") + + //Should allow cleartext => dialer.StartTLSPolicy = gomail.OpportunisticStartTLS and dialer.SSL = false + smtp = setupConfig("smtp://foo:bar@my-server:1234/?disable_starttls=true") + assert.Equal(t, smtp.SmtpDialer().StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced") + assert.Equal(t, smtp.SmtpDialer().SSL, false, "Implicit TLS should not be enabled") +} + +func TestQueueEmail(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + + smtp, api, err := x.RunTestSMTP() + require.NoError(t, err) + t.Logf("SMTP URL: %s", smtp) + t.Logf("API URL: %s", api) + + ctx := context.Background() + + conf, reg := internal.NewRegistryDefaultWithDSN(t, "") + conf.MustSet(config.ViperKeyCourierSMTPURL, smtp) + conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh") + reg.Logger().Level = logrus.TraceLevel + + c := reg.Courier(ctx) //??? + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + id, err := c.QueueEmail(ctx, templates.NewTestEmailStub(conf, &templates.TestStubModel{ + To: "test-recipient-1@example.org", + Subject: "test-subject-1", + Body: "test-body-1", + })) + require.NoError(t, err) + x.RequireNotNilUUID(t, id) + + id, err = c.QueueEmail(ctx, templates.NewTestEmailStub(conf, &templates.TestStubModel{ + To: "test-recipient-2@example.org", + Subject: "test-subject-2", + Body: "test-body-2", + })) + require.NoError(t, err) + x.RequireNotNilUUID(t, id) + + // The third email contains a sender name and custom headers + conf.MustSet(config.ViperKeyCourierSMTPFromName, "Bob") + conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header1", "foo") + conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header2", "bar") + customerHeaders := conf.CourierSMTPHeaders() + require.Len(t, customerHeaders, 2) + id, err = c.QueueEmail(ctx, templates.NewTestEmailStub(conf, &templates.TestStubModel{ + To: "test-recipient-3@example.org", + Subject: "test-subject-3", + Body: "test-body-3", + })) + require.NoError(t, err) + x.RequireNotNilUUID(t, id) + + go func() { + require.NoError(t, c.Work(ctx)) + }() + + var body []byte + for k := 0; k < 30; k++ { + time.Sleep(time.Second) + err = func() error { + res, err := http.Get(api + "/api/v2/messages") + if err != nil { + return err + } + + defer res.Body.Close() + body, err = ioutil.ReadAll(res.Body) + if err != nil { + return err + } + + if http.StatusOK != res.StatusCode { + return errors.Errorf("expected status code 200 but got %d with body: %s", res.StatusCode, body) + } + + if total := gjson.GetBytes(body, "total").Int(); total != 3 { + return errors.Errorf("expected to have delivered at least 3 messages but got count %d with body: %s", total, body) + } + + return nil + }() + if err == nil { + break + } + } + require.NoError(t, err) + + for k := 1; k <= 3; k++ { + assert.Contains(t, string(body), fmt.Sprintf("test-subject-%d", k)) + assert.Contains(t, string(body), fmt.Sprintf("test-body-%d", k)) + assert.Contains(t, string(body), fmt.Sprintf("test-recipient-%d@example.org", k)) + assert.Contains(t, string(body), "test-stub@ory.sh") + } + + // Assertion for the third email with sender name and headers + assert.Contains(t, string(body), "Bob") + assert.Contains(t, string(body), `"test-stub-header1":["foo"]`) + assert.Contains(t, string(body), `"test-stub-header2":["bar"]`) +} diff --git a/courier/stub/request.config.twilio.jsonnet b/courier/stub/request.config.twilio.jsonnet new file mode 100644 index 000000000000..93752e145035 --- /dev/null +++ b/courier/stub/request.config.twilio.jsonnet @@ -0,0 +1,5 @@ +function(ctx) { + from: ctx.From, + to: ctx.To, + body: ctx.Body +} diff --git a/courier/template/courier/builtin/templates/otp/sms.body.gotmpl b/courier/template/courier/builtin/templates/otp/sms.body.gotmpl new file mode 100644 index 000000000000..a630a83b82db --- /dev/null +++ b/courier/template/courier/builtin/templates/otp/sms.body.gotmpl @@ -0,0 +1,3 @@ +Hi, please verify your account using following code: + +{{ .Code }} diff --git a/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl b/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl new file mode 100644 index 000000000000..a37e4640152d --- /dev/null +++ b/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl @@ -0,0 +1 @@ +stub sms body {{ .Body }} diff --git a/courier/template/email/recovery_invalid.go b/courier/template/email/recovery_invalid.go new file mode 100644 index 000000000000..09e3037b9617 --- /dev/null +++ b/courier/template/email/recovery_invalid.go @@ -0,0 +1,41 @@ +package email + +import ( + "encoding/json" + + "github.com/ory/kratos/courier/template" +) + +type ( + RecoveryInvalid struct { + c template.Config + m *RecoveryInvalidModel + } + RecoveryInvalidModel struct { + To string + } +) + +func NewRecoveryInvalid(c template.Config, m *RecoveryInvalidModel) *RecoveryInvalid { + return &RecoveryInvalid{c: c, m: m} +} + +func (t *RecoveryInvalid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *RecoveryInvalid) EmailSubject() (string, error) { + return template.LoadText(t.c.CourierTemplatesRoot(), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m) +} + +func (t *RecoveryInvalid) EmailBody() (string, error) { + return template.LoadHTML(t.c.CourierTemplatesRoot(), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m) +} + +func (t *RecoveryInvalid) EmailBodyPlaintext() (string, error) { + return template.LoadText(t.c.CourierTemplatesRoot(), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m) +} + +func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/recovery_invalid_test.go b/courier/template/email/recovery_invalid_test.go similarity index 74% rename from courier/template/recovery_invalid_test.go rename to courier/template/email/recovery_invalid_test.go index 021efc100a8e..74c5876ddc91 100644 --- a/courier/template/recovery_invalid_test.go +++ b/courier/template/email/recovery_invalid_test.go @@ -1,18 +1,19 @@ -package template_test +package email_test import ( "testing" + "github.com/ory/kratos/courier/template/email" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) func TestRecoverInvalid(t *testing.T) { conf, _ := internal.NewFastRegistryWithMocks(t) - tpl := template.NewRecoveryInvalid(conf, &template.RecoveryInvalidModel{}) + tpl := email.NewRecoveryInvalid(conf, &email.RecoveryInvalidModel{}) rendered, err := tpl.EmailBody() require.NoError(t, err) diff --git a/courier/template/recovery_valid.go b/courier/template/email/recovery_valid.go similarity index 51% rename from courier/template/recovery_valid.go rename to courier/template/email/recovery_valid.go index 47eed72534a3..c7a8831742cf 100644 --- a/courier/template/recovery_valid.go +++ b/courier/template/email/recovery_valid.go @@ -1,12 +1,14 @@ -package template +package email import ( "encoding/json" + + "github.com/ory/kratos/courier/template" ) type ( RecoveryValid struct { - c TemplateConfig + c template.Config m *RecoveryValidModel } RecoveryValidModel struct { @@ -16,7 +18,7 @@ type ( } ) -func NewRecoveryValid(c TemplateConfig, m *RecoveryValidModel) *RecoveryValid { +func NewRecoveryValid(c template.Config, m *RecoveryValidModel) *RecoveryValid { return &RecoveryValid{c: c, m: m} } @@ -25,15 +27,15 @@ func (t *RecoveryValid) EmailRecipient() (string, error) { } func (t *RecoveryValid) EmailSubject() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m) + return template.LoadText(t.c.CourierTemplatesRoot(), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m) } func (t *RecoveryValid) EmailBody() (string, error) { - return LoadHTMLTemplate(t.c.CourierTemplatesRoot(), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m) + return template.LoadHTML(t.c.CourierTemplatesRoot(), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m) } func (t *RecoveryValid) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m) + return template.LoadText(t.c.CourierTemplatesRoot(), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m) } func (t *RecoveryValid) MarshalJSON() ([]byte, error) { diff --git a/courier/template/recovery_valid_test.go b/courier/template/email/recovery_valid_test.go similarity index 75% rename from courier/template/recovery_valid_test.go rename to courier/template/email/recovery_valid_test.go index 09d355e14555..a191789b9997 100644 --- a/courier/template/recovery_valid_test.go +++ b/courier/template/email/recovery_valid_test.go @@ -1,4 +1,4 @@ -package template_test +package email_test import ( "testing" @@ -6,13 +6,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/courier/template" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/internal" ) func TestRecoverValid(t *testing.T) { conf, _ := internal.NewFastRegistryWithMocks(t) - tpl := template.NewRecoveryValid(conf, &template.RecoveryValidModel{}) + tpl := email.NewRecoveryValid(conf, &email.RecoveryValidModel{}) rendered, err := tpl.EmailBody() require.NoError(t, err) diff --git a/courier/template/email/stub.go b/courier/template/email/stub.go new file mode 100644 index 000000000000..434b47a390cf --- /dev/null +++ b/courier/template/email/stub.go @@ -0,0 +1,43 @@ +package email + +import ( + "encoding/json" + + "github.com/ory/kratos/courier/template" +) + +type ( + TestStub struct { + c template.Config + m *TestStubModel + } + TestStubModel struct { + To string + Subject string + Body string + } +) + +func NewTestEmailStub(c template.Config, m *TestStubModel) *TestStub { + return &TestStub{c: c, m: m} +} + +func (t *TestStub) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *TestStub) EmailSubject() (string, error) { + return template.LoadText(t.c.CourierTemplatesRoot(), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m) +} + +func (t *TestStub) EmailBody() (string, error) { + return template.LoadHTML(t.c.CourierTemplatesRoot(), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m) +} + +func (t *TestStub) EmailBodyPlaintext() (string, error) { + return template.LoadText(t.c.CourierTemplatesRoot(), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m) +} + +func (t *TestStub) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/email/verification_invalid.go b/courier/template/email/verification_invalid.go new file mode 100644 index 000000000000..a9092c1b15e2 --- /dev/null +++ b/courier/template/email/verification_invalid.go @@ -0,0 +1,41 @@ +package email + +import ( + "encoding/json" + + "github.com/ory/kratos/courier/template" +) + +type ( + VerificationInvalid struct { + c template.Config + m *VerificationInvalidModel + } + VerificationInvalidModel struct { + To string + } +) + +func NewVerificationInvalid(c template.Config, m *VerificationInvalidModel) *VerificationInvalid { + return &VerificationInvalid{c: c, m: m} +} + +func (t *VerificationInvalid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *VerificationInvalid) EmailSubject() (string, error) { + return template.LoadText(t.c.CourierTemplatesRoot(), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m) +} + +func (t *VerificationInvalid) EmailBody() (string, error) { + return template.LoadHTML(t.c.CourierTemplatesRoot(), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m) +} + +func (t *VerificationInvalid) EmailBodyPlaintext() (string, error) { + return template.LoadText(t.c.CourierTemplatesRoot(), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m) +} + +func (t *VerificationInvalid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/verification_invalid_test.go b/courier/template/email/verification_invalid_test.go similarity index 73% rename from courier/template/verification_invalid_test.go rename to courier/template/email/verification_invalid_test.go index 5d77cae21fb1..ecbf83211056 100644 --- a/courier/template/verification_invalid_test.go +++ b/courier/template/email/verification_invalid_test.go @@ -1,18 +1,19 @@ -package template_test +package email_test import ( "testing" + "github.com/ory/kratos/courier/template/email" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) func TestVerifyInvalid(t *testing.T) { conf, _ := internal.NewFastRegistryWithMocks(t) - tpl := template.NewVerificationInvalid(conf, &template.VerificationInvalidModel{}) + tpl := email.NewVerificationInvalid(conf, &email.VerificationInvalidModel{}) rendered, err := tpl.EmailBody() require.NoError(t, err) diff --git a/courier/template/verification_valid.go b/courier/template/email/verification_valid.go similarity index 51% rename from courier/template/verification_valid.go rename to courier/template/email/verification_valid.go index f408ad9bcc29..87a0c99ecfd9 100644 --- a/courier/template/verification_valid.go +++ b/courier/template/email/verification_valid.go @@ -1,12 +1,14 @@ -package template +package email import ( "encoding/json" + + "github.com/ory/kratos/courier/template" ) type ( VerificationValid struct { - c TemplateConfig + c template.Config m *VerificationValidModel } VerificationValidModel struct { @@ -16,7 +18,7 @@ type ( } ) -func NewVerificationValid(c TemplateConfig, m *VerificationValidModel) *VerificationValid { +func NewVerificationValid(c template.Config, m *VerificationValidModel) *VerificationValid { return &VerificationValid{c: c, m: m} } @@ -25,15 +27,15 @@ func (t *VerificationValid) EmailRecipient() (string, error) { } func (t *VerificationValid) EmailSubject() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m) + return template.LoadText(t.c.CourierTemplatesRoot(), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m) } func (t *VerificationValid) EmailBody() (string, error) { - return LoadHTMLTemplate(t.c.CourierTemplatesRoot(), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m) + return template.LoadHTML(t.c.CourierTemplatesRoot(), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m) } func (t *VerificationValid) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m) + return template.LoadText(t.c.CourierTemplatesRoot(), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m) } func (t *VerificationValid) MarshalJSON() ([]byte, error) { diff --git a/courier/template/verification_valid_test.go b/courier/template/email/verification_valid_test.go similarity index 73% rename from courier/template/verification_valid_test.go rename to courier/template/email/verification_valid_test.go index f80fbf4cec45..ec3706fbc39e 100644 --- a/courier/template/verification_valid_test.go +++ b/courier/template/email/verification_valid_test.go @@ -1,4 +1,4 @@ -package template_test +package email_test import ( "testing" @@ -6,13 +6,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/ory/kratos/courier/template" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/internal" ) func TestVerifyValid(t *testing.T) { conf, _ := internal.NewFastRegistryWithMocks(t) - tpl := template.NewVerificationValid(conf, &template.VerificationValidModel{}) + tpl := email.NewVerificationValid(conf, &email.VerificationValidModel{}) rendered, err := tpl.EmailBody() require.NoError(t, err) diff --git a/courier/template/load_template.go b/courier/template/load_template.go index 8f025a74121a..a10d8e0cf563 100644 --- a/courier/template/load_template.go +++ b/courier/template/load_template.go @@ -24,6 +24,30 @@ type Template interface { Execute(wr io.Writer, data interface{}) error } +func LoadText(osdir, name, pattern string, model interface{}) (string, error) { + t, err := loadTemplate(osdir, name, pattern, false) + if err != nil { + return "", err + } + var b bytes.Buffer + if err := t.Execute(&b, model); err != nil { + return "", err + } + return b.String(), nil +} + +func LoadHTML(osdir, name, pattern string, model interface{}) (string, error) { + t, err := loadTemplate(osdir, name, pattern, true) + if err != nil { + return "", err + } + var b bytes.Buffer + if err := t.Execute(&b, model); err != nil { + return "", err + } + return b.String(), nil +} + func loadBuiltInTemplate(osdir, name string, html bool) (Template, error) { if t, found := cache.Get(name); found { return t.(Template), nil @@ -104,27 +128,3 @@ func loadTemplate(osdir, name, pattern string, html bool) (Template, error) { _ = cache.Add(name, tpl) return tpl, nil } - -func LoadTextTemplate(osdir, name, pattern string, model interface{}) (string, error) { - t, err := loadTemplate(osdir, name, pattern, false) - if err != nil { - return "", err - } - var b bytes.Buffer - if err := t.Execute(&b, model); err != nil { - return "", err - } - return b.String(), nil -} - -func LoadHTMLTemplate(osdir, name, pattern string, model interface{}) (string, error) { - t, err := loadTemplate(osdir, name, pattern, true) - if err != nil { - return "", err - } - var b bytes.Buffer - if err := t.Execute(&b, model); err != nil { - return "", err - } - return b.String(), nil -} diff --git a/courier/template/load_template_test.go b/courier/template/load_template_test.go index fa17c77a1d42..b90b2371f530 100644 --- a/courier/template/load_template_test.go +++ b/courier/template/load_template_test.go @@ -14,13 +14,13 @@ import ( func TestLoadTextTemplate(t *testing.T) { var executeTextTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string { - tp, err := LoadTextTemplate(dir, name, pattern, model) + tp, err := LoadText(dir, name, pattern, model) require.NoError(t, err) return tp } var executeHTMLTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string { - tp, err := LoadHTMLTemplate(dir, name, pattern, model) + tp, err := LoadHTML(dir, name, pattern, model) require.NoError(t, err) return tp } diff --git a/courier/template/recovery_invalid.go b/courier/template/recovery_invalid.go deleted file mode 100644 index e938bc9947a5..000000000000 --- a/courier/template/recovery_invalid.go +++ /dev/null @@ -1,39 +0,0 @@ -package template - -import ( - "encoding/json" -) - -type ( - RecoveryInvalid struct { - c TemplateConfig - m *RecoveryInvalidModel - } - RecoveryInvalidModel struct { - To string - } -) - -func NewRecoveryInvalid(c TemplateConfig, m *RecoveryInvalidModel) *RecoveryInvalid { - return &RecoveryInvalid{c: c, m: m} -} - -func (t *RecoveryInvalid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *RecoveryInvalid) EmailSubject() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m) -} - -func (t *RecoveryInvalid) EmailBody() (string, error) { - return LoadHTMLTemplate(t.c.CourierTemplatesRoot(), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m) -} - -func (t *RecoveryInvalid) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m) -} - -func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/sms/otp.go b/courier/template/sms/otp.go new file mode 100644 index 000000000000..d55f31a93b94 --- /dev/null +++ b/courier/template/sms/otp.go @@ -0,0 +1,35 @@ +package sms + +import ( + "encoding/json" + + "github.com/ory/kratos/courier/template" +) + +type ( + OTPMessage struct { + c template.Config + m *OTPMessageModel + } + + OTPMessageModel struct { + To string + Code string + } +) + +func NewOTPMessage(c template.Config, m *OTPMessageModel) *OTPMessage { + return &OTPMessage{c: c, m: m} +} + +func (t *OTPMessage) PhoneNumber() (string, error) { + return t.m.To, nil +} + +func (t *OTPMessage) SMSBody() (string, error) { + return template.LoadText(t.c.CourierTemplatesRoot(), "otp/sms.body.gotmpl", "otp/sms.body*", t.m) +} + +func (t *OTPMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/sms/otp_test.go b/courier/template/sms/otp_test.go new file mode 100644 index 000000000000..51109e465c88 --- /dev/null +++ b/courier/template/sms/otp_test.go @@ -0,0 +1,33 @@ +package sms_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestNewOTPMessage(t *testing.T) { + conf, _ := internal.NewFastRegistryWithMocks(t) + + const ( + expectedPhone = "+12345678901" + otp = "012345" + ) + + tpl := sms.NewOTPMessage(conf, &sms.OTPMessageModel{To: expectedPhone, Code: otp}) + + expectedBody := fmt.Sprintf("Hi, please verify your account using following code:\n\n%s\n", otp) + + actualBody, err := tpl.SMSBody() + require.NoError(t, err) + assert.Equal(t, expectedBody, actualBody) + + actualPhone, err := tpl.PhoneNumber() + require.NoError(t, err) + assert.Equal(t, expectedPhone, actualPhone) +} diff --git a/courier/template/sms/stub.go b/courier/template/sms/stub.go new file mode 100644 index 000000000000..89fb4370dd16 --- /dev/null +++ b/courier/template/sms/stub.go @@ -0,0 +1,35 @@ +package sms + +import ( + "encoding/json" + + "github.com/ory/kratos/courier/template" +) + +type ( + TestStub struct { + c template.Config + m *TestStubModel + } + + TestStubModel struct { + To string + Body string + } +) + +func NewTestStub(c template.Config, m *TestStubModel) *TestStub { + return &TestStub{c: c, m: m} +} + +func (t *TestStub) PhoneNumber() (string, error) { + return t.m.To, nil +} + +func (t *TestStub) SMSBody() (string, error) { + return template.LoadText(t.c.CourierTemplatesRoot(), "otp/test_stub/sms.body.gotmpl", "otp/test_stub/sms.body*", t.m) +} + +func (t *TestStub) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/sms/stub_test.go b/courier/template/sms/stub_test.go new file mode 100644 index 000000000000..95432b3feda9 --- /dev/null +++ b/courier/template/sms/stub_test.go @@ -0,0 +1,30 @@ +package sms_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestNewTestStub(t *testing.T) { + conf, _ := internal.NewFastRegistryWithMocks(t) + + const ( + expectedPhone = "+12345678901" + expectedBody = "test sms" + ) + + tpl := sms.NewTestStub(conf, &sms.TestStubModel{To: expectedPhone, Body: expectedBody}) + + actualBody, err := tpl.SMSBody() + require.NoError(t, err) + assert.Equal(t, "stub sms body test sms\n", actualBody) + + actualPhone, err := tpl.PhoneNumber() + require.NoError(t, err) + assert.Equal(t, expectedPhone, actualPhone) +} diff --git a/courier/template/stub.go b/courier/template/stub.go deleted file mode 100644 index 48388cbb4114..000000000000 --- a/courier/template/stub.go +++ /dev/null @@ -1,40 +0,0 @@ -package template - -import ( - "encoding/json" -) - -type TestStub struct { - c TemplateConfig - m *TestStubModel -} - -type TestStubModel struct { - To string - Subject string - Body string -} - -func NewTestStub(c TemplateConfig, m *TestStubModel) *TestStub { - return &TestStub{c: c, m: m} -} - -func (t *TestStub) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *TestStub) EmailSubject() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m) -} - -func (t *TestStub) EmailBody() (string, error) { - return LoadHTMLTemplate(t.c.CourierTemplatesRoot(), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m) -} - -func (t *TestStub) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m) -} - -func (t *TestStub) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/template.go b/courier/template/template.go index 0486356fb0e8..0329987e98a7 100644 --- a/courier/template/template.go +++ b/courier/template/template.go @@ -1,7 +1,7 @@ package template type ( - TemplateConfig interface { + Config interface { CourierTemplatesRoot() string } ) diff --git a/courier/template/verification_invalid.go b/courier/template/verification_invalid.go deleted file mode 100644 index a145cbf66730..000000000000 --- a/courier/template/verification_invalid.go +++ /dev/null @@ -1,39 +0,0 @@ -package template - -import ( - "encoding/json" -) - -type ( - VerificationInvalid struct { - c TemplateConfig - m *VerificationInvalidModel - } - VerificationInvalidModel struct { - To string - } -) - -func NewVerificationInvalid(c TemplateConfig, m *VerificationInvalidModel) *VerificationInvalid { - return &VerificationInvalid{c: c, m: m} -} - -func (t *VerificationInvalid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *VerificationInvalid) EmailSubject() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m) -} - -func (t *VerificationInvalid) EmailBody() (string, error) { - return LoadHTMLTemplate(t.c.CourierTemplatesRoot(), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m) -} - -func (t *VerificationInvalid) EmailBodyPlaintext() (string, error) { - return LoadTextTemplate(t.c.CourierTemplatesRoot(), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m) -} - -func (t *VerificationInvalid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/docs/docs/concepts/email-sms.md b/docs/docs/concepts/email-sms.md index 1ceed215ab02..2353746f536a 100644 --- a/docs/docs/concepts/email-sms.md +++ b/docs/docs/concepts/email-sms.md @@ -27,7 +27,7 @@ the `--watch-courier` flag to your `kratos serve` command, as outlined in the ### Multi-instance -If you're running multiple instances of Kratos (eg replicated Kubernetes +If you're running multiple instances of Kratos (e.g. replicated Kubernetes deployment), you need to run the mail courier as a separate singleton job. The courier can be started with the `kratos courier watch` command ([CLI docs](../cli/kratos-courier.md)). @@ -59,7 +59,8 @@ courier: ### Sender Address and Template Customization -You can customize the sender address and email templates. +You can customize the sender address and email templates by overriding path to +the templates folder. See more about custom templates in templates section. ```yaml title="path/to/my/kratos/config.yml" # $ kratos -c path/to/my/kratos/config.yml serve @@ -94,6 +95,63 @@ courier: template_override_path: /conf/courier-templates ``` +### Custom Headers + +You can configure custom SMTP headers. For example, if integrating with AWS SES +SMTP interface, the headers can be configured for cross-account sending: + +```yaml title="path/to/my/kratos/config.yml" +# $ kratos -c path/to/my/kratos/config.yml serve +courier: + smtp: + headers: + X-SES-SOURCE-ARN: arn:aws:ses:us-west-2:123456789012:identity/example.com + X-SES-FROM-ARN: arn:aws:ses:us-west-2:123456789012:identity/example.com + X-SES-RETURN-PATH-ARN: arn:aws:ses:us-west-2:123456789012:identity/example.com +``` + +## Sending SMS + +For sending SMS Ory Kratos requires an external SMS gateway, which must be able +to satisfy the HTTP contract. The address of the SMS gateway endpoint must be +set in the configuration file. Please note that it needs to be absolute, start +with http:// or https:// scheme, and include path part - e.g. +"https://api.sender.com/v1/message". + +```yaml title="path/to/my/kratos/config.yml" +# $ kratos -c path/to/my/kratos/config.yml serve +courier: + sms: + host: https://api.sender.com/v1/message +``` + +### HTTP API contract + +Kratos will send a POST request to the endpoint set in the config for phone +verification and recovery (note that recovery using SMS is currently not +implemented). Post request body has urlencoded format, and contains three +parameters - "To", "From", and "Body". To - phone number of the recipient. +From - configurable sender name. Body - SMS text. It will contain a link or OTP +code for verification(in progress). + +Authorization with third party SMS gate via access tokens are not implemented +yet. + +### Sender name + +The recipient of an email will see this as the sender’s address. You can +customize the sender name for SMS by overriding the "from_name" config property. +Default sender name is equal to "Kratos". + +```yaml title="path/to/my/kratos/config.yml" +# $ kratos -c path/to/my/kratos/config.yml serve +courier: + sms: + from_name: 'Your Org Name' +``` + +## Message templates + Ory Kratos comes with built-in templates. If you wish to define your own, custom templates, you should define `template_override_path`, as shown above, to indicate where your custom templates are located. This will become the @@ -200,23 +258,3 @@ As indicated by the example, we need a root template, which is the the following pattern: `email.body*`. You can also see that the `Identity` of the user is available in all templates, and that you can use Sprig functions also in the nested templates. - -### Custom Headers - -You can configure custom SMTP headers. For example, if integrating with AWS SES -SMTP interface, the headers can be configured for cross-account sending: - -```yaml title="path/to/my/kratos/config.yml" -# $ kratos -c path/to/my/kratos/config.yml serve -courier: - smtp: - headers: - X-SES-SOURCE-ARN: arn:aws:ses:us-west-2:123456789012:identity/example.com - X-SES-FROM-ARN: arn:aws:ses:us-west-2:123456789012:identity/example.com - X-SES-RETURN-PATH-ARN: arn:aws:ses:us-west-2:123456789012:identity/example.com -``` - -## Sending SMS - -The Sending SMS feature is not supported at present. It will be available in a -future version of Ory Kratos. diff --git a/driver/config/config.go b/driver/config/config.go index fc3dd0be3f45..b0d8ebfb6595 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -62,6 +62,9 @@ const ( ViperKeyCourierSMTPFrom = "courier.smtp.from_address" ViperKeyCourierSMTPFromName = "courier.smtp.from_name" ViperKeyCourierSMTPHeaders = "courier.smtp.headers" + ViperKeyCourierSMSRequestConfig = "courier.sms.request_config" + ViperKeyCourierSMSEnabled = "courier.sms.enabled" + ViperKeyCourierSMSFrom = "courier.sms.from_name" ViperKeySecretsDefault = "secrets.default" ViperKeySecretsCookie = "secrets.cookie" ViperKeySecretsCipher = "secrets.cipher" @@ -817,6 +820,33 @@ func (p *Config) CourierSMTPHeaders() map[string]string { return p.p.StringMap(ViperKeyCourierSMTPHeaders) } +func (p *Config) CourierSMSRequestConfig() json.RawMessage { + if !p.p.Bool(ViperKeyCourierSMSEnabled) { + return nil + } + + out, err := p.p.Marshal(kjson.Parser()) + if err != nil { + p.l.WithError(err).Warn("Unable to marshal self service strategy configuration.") + return nil + } + + config := gjson.GetBytes(out, ViperKeyCourierSMSRequestConfig).Raw + if len(config) <= 0 { + return json.RawMessage("{}") + } + + return json.RawMessage(config) +} + +func (p *Config) CourierSMSFrom() string { + return p.p.StringF(ViperKeyCourierSMSFrom, "Ory Kratos") +} + +func (p *Config) CourierSMSEnabled() bool { + return p.p.Bool(ViperKeyCourierSMSEnabled) +} + func splitUrlAndFragment(s string) (string, string) { i := strings.IndexByte(s, '#') if i < 0 { diff --git a/driver/registry_default.go b/driver/registry_default.go index cd8f53e9b599..919bd414a890 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -259,11 +259,11 @@ func (m *RegistryDefault) Config(ctx context.Context) *config.Config { return corp.ContextualizeConfig(ctx, m.c) } -func (m *RegistryDefault) CourierConfig(ctx context.Context) courier.SMTPConfig { +func (m *RegistryDefault) CourierConfig(ctx context.Context) courier.Config { return m.Config(ctx) } -func (m *RegistryDefault) SMTPConfig(ctx context.Context) courier.SMTPConfig { +func (m *RegistryDefault) SMTPConfig(ctx context.Context) courier.Config { return m.Config(ctx) } @@ -586,8 +586,8 @@ func (m *RegistryDefault) SetPersister(p persistence.Persister) { m.persister = p } -func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier { - return courier.NewSMTP(ctx, m) +func (m *RegistryDefault) Courier(ctx context.Context) courier.Courier { + return courier.NewCourier(ctx, m) } func (m *RegistryDefault) ContinuityManager() continuity.Manager { diff --git a/embedx/config.schema.json b/embedx/config.schema.json index eafad4ae6984..d17c93ba53d8 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -1351,6 +1351,29 @@ "connection_uri" ], "additionalProperties": false + }, + "sms": { + "title": "SMS sender configuration", + "description": "Configures outgoing sms messages using HTTP protocol with generic SMS provider", + "type": "object", + "properties": { + "host": { + "title": "HTTP address of API endpoint", + "description": "This URL will be used to connect to SMS provider.", + "examples": [ + "https://api.twillio.com/sms/send" + ], + "type": "string", + "pattern": "^https?:\\/\\/.*" + }, + "from_name": { + "title": "SMS Sender Address", + "description": "The recipient of a sms will see this as the sender address.", + "type": "string", + "default": "Ory Kratos" + } + }, + "additionalProperties": false } }, "required": [ diff --git a/request/auth.go b/request/auth.go new file mode 100644 index 000000000000..65df14402fa7 --- /dev/null +++ b/request/auth.go @@ -0,0 +1,31 @@ +package request + +import ( + "encoding/json" + "fmt" + "net/http" +) + +type ( + AuthStrategy interface { + apply(req *http.Request) + } + + authStrategyFactory func(c json.RawMessage) (AuthStrategy, error) +) + +var strategyFactories = map[string]authStrategyFactory{ + "": newNoopAuthStrategy, + "api_key": newApiKeyStrategy, + "basic_auth": newBasicAuthStrategy, +} + +func authStrategy(name string, config json.RawMessage) (AuthStrategy, error) { + strategyFactory, ok := strategyFactories[name] + if ok { + return strategyFactory(config) + } + + return nil, fmt.Errorf("unsupported auth type: %s", name) + +} diff --git a/request/auth_strategy.go b/request/auth_strategy.go new file mode 100644 index 000000000000..e2e41b9e0f87 --- /dev/null +++ b/request/auth_strategy.go @@ -0,0 +1,76 @@ +package request + +import ( + "encoding/json" + "net/http" +) + +type ( + noopAuthStrategy struct{} + + basicAuthStrategy struct { + user string + password string + } + + apiKeyStrategy struct { + name string + value string + in string + } +) + +func newNoopAuthStrategy(_ json.RawMessage) (AuthStrategy, error) { + return &noopAuthStrategy{}, nil +} + +func (c *noopAuthStrategy) apply(_ *http.Request) {} + +func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) { + type config struct { + User string + Password string + } + + var c config + if err := json.Unmarshal(raw, &c); err != nil { + return nil, err + } + + return &basicAuthStrategy{ + user: c.User, + password: c.Password, + }, nil +} + +func (c *basicAuthStrategy) apply(req *http.Request) { + req.SetBasicAuth(c.user, c.password) +} + +func newApiKeyStrategy(raw json.RawMessage) (AuthStrategy, error) { + type config struct { + In string + Name string + Value string + } + + var c config + if err := json.Unmarshal(raw, &c); err != nil { + return nil, err + } + + return &apiKeyStrategy{ + in: c.In, + name: c.Name, + value: c.Value, + }, nil +} + +func (c *apiKeyStrategy) apply(req *http.Request) { + switch c.in { + case "cookie": + req.AddCookie(&http.Cookie{Name: c.name, Value: c.value}) + default: + req.Header.Set(c.name, c.value) + } +} diff --git a/request/auth_strategy_test.go b/request/auth_strategy_test.go new file mode 100644 index 000000000000..b22d140c46e1 --- /dev/null +++ b/request/auth_strategy_test.go @@ -0,0 +1,67 @@ +package request + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNoopAuthStrategy(t *testing.T) { + req := http.Request{Header: map[string][]string{}} + auth := noopAuthStrategy{} + + auth.apply(&req) + + assert.Empty(t, req.Header, "Empty auth strategy shall not modify any request headers") +} + +func TestBasicAuthStrategy(t *testing.T) { + req := http.Request{Header: map[string][]string{}} + auth := basicAuthStrategy{ + user: "test-user", + password: "test-pass", + } + + auth.apply(&req) + + assert.Len(t, req.Header, 1) + + user, pass, _ := req.BasicAuth() + assert.Equal(t, "test-user", user) + assert.Equal(t, "test-pass", pass) +} + +func TestApiKeyInHeaderStrategy(t *testing.T) { + req := http.Request{Header: map[string][]string{}} + auth := apiKeyStrategy{ + in: "header", + name: "my-api-key-name", + value: "my-api-key-value", + } + + auth.apply(&req) + + require.Len(t, req.Header, 1) + + actualValue := req.Header.Get("my-api-key-name") + assert.Equal(t, "my-api-key-value", actualValue) +} + +func TestApiKeyInCookieStrategy(t *testing.T) { + req := http.Request{Header: map[string][]string{}} + auth := apiKeyStrategy{ + in: "cookie", + name: "my-api-key-name", + value: "my-api-key-value", + } + + auth.apply(&req) + + cookies := req.Cookies() + assert.Len(t, cookies, 1) + + assert.Equal(t, "my-api-key-name", cookies[0].Name) + assert.Equal(t, "my-api-key-value", cookies[0].Value) +} diff --git a/request/auth_test.go b/request/auth_test.go new file mode 100644 index 000000000000..c0df79336905 --- /dev/null +++ b/request/auth_test.go @@ -0,0 +1,56 @@ +package request + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthStrategy(t *testing.T) { + for _, tc := range map[string]struct { + name string + config string + expected AuthStrategy + }{ + "noop": { + name: "", + config: "", + expected: &noopAuthStrategy{}, + }, + "basic_auth": { + name: "basic_auth", + config: `{ + "user": "test-api-user", + "password": "secret" + }`, + expected: &basicAuthStrategy{}, + }, + "api-key/header": { + name: "api_key", + config: `{ + "in": "header", + "name": "my-api-key", + "value": "secret" + }`, + expected: &apiKeyStrategy{}, + }, + "api-key/cookie": { + name: "api_key", + config: `{ + "in": "cookie", + "name": "my-api-key", + "value": "secret" + }`, + expected: &apiKeyStrategy{}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + strategy, err := authStrategy(tc.name, json.RawMessage(tc.config)) + require.NoError(t, err) + + assert.IsTypef(t, tc.expected, strategy, "auth strategy should be of the expected type") + }) + } +} diff --git a/request/builder.go b/request/builder.go new file mode 100644 index 000000000000..6364a6baa7f0 --- /dev/null +++ b/request/builder.go @@ -0,0 +1,120 @@ +package request + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + + "github.com/google/go-jsonnet" + + "github.com/ory/x/fetcher" + "github.com/ory/x/logrusx" +) + +type Builder struct { + r *http.Request + log *logrusx.Logger + conf *Config +} + +func NewBuilder(config json.RawMessage, l *logrusx.Logger) (*Builder, error) { + c, err := parseConfig(config) + if err != nil { + return nil, err + } + + r, err := http.NewRequest(c.Method, c.Url, nil) + if err != nil { + return nil, err + } + + return &Builder{ + r: r, + log: l, + conf: c, + }, nil +} + +func (b *Builder) addAuth() error { + authConfig := b.conf.Auth + + strategy, err := authStrategy(authConfig.Type, authConfig.Config) + if err != nil { + return err + } + + strategy.apply(b.r) + + return nil +} + +func (b *Builder) addBody(body interface{}) error { + bodyReader, err := readBody(b.conf.TemplateURI, body, b.log) + if err != nil { + return err + } + + b.r.Body = io.NopCloser(bodyReader) + b.r.ContentLength = int64(bodyReader.Len()) + + return nil +} + +func (b *Builder) BuildRequest(body interface{}) (*http.Request, error) { + if err := b.addAuth(); err != nil { + return nil, err + } + + b.r.Header.Set("Content-Type", "application/json") + + // According to the HTTP spec any request method, but TRACE is allowed to + // have a body. Even this is a bad practice for some of them, like for GET + if b.conf.Method != http.MethodTrace { + if err := b.addBody(body); err != nil { + return nil, err + } + } + + return b.r, nil +} + +func readBody(templateURI string, data interface{}, l *logrusx.Logger) (*bytes.Reader, error) { + if templateURI == "" { + return nil, nil + } + + f := fetcher.NewFetcher() + + template, err := f.Fetch(templateURI) + if errors.Is(err, fetcher.ErrUnknownScheme) { + // legacy filepath + templateURI = "file://" + templateURI + l.WithError(err).Warnf("support for filepaths without a 'file://' scheme will be dropped in the next release, please use %s instead in your config", templateURI) + template, err = f.Fetch(templateURI) + } + // this handles the first error if it is a known scheme error, or the second fetch error + if err != nil { + return nil, err + } + + vm := jsonnet.MakeVM() + + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + enc.SetIndent("", "") + + if err := enc.Encode(data); err != nil { + return nil, err + } + vm.TLACode("ctx", buf.String()) + + res, err := vm.EvaluateAnonymousSnippet(templateURI, template.String()) + if err != nil { + return nil, err + } + + return bytes.NewReader([]byte(res)), nil +} diff --git a/request/builder_test.go b/request/builder_test.go new file mode 100644 index 000000000000..50b53c6e1437 --- /dev/null +++ b/request/builder_test.go @@ -0,0 +1,98 @@ +package request + +//import ( +// "github.com/ory/x/logrusx" +// "github.com/sirupsen/logrus/hooks/test" +// "io" +//) +// +////go:embed stub/test_body.jsonnet +//var testBodyJSONNet []byte +// +//func TestJsonNetSupport(t *testing.T) { +// f := &login.Flow{ID: x.NewUUID()} +// i := identity.NewIdentity("") +// l := logrusx.New("kratos", "test") +// +// for _, tc := range []struct { +// desc, template string +// data *templateContext +// }{ +// { +// desc: "simple file URI", +// template: "file://./stub/test_body.jsonnet", +// data: &templateContext{ +// Flow: f, +// RequestHeaders: http.Header{ +// "Cookie": []string{"c1=v1", "c2=v2"}, +// "Some-Header": []string{"Some-Value"}, +// }, +// RequestMethod: "POST", +// RequestUrl: "https://test.kratos.ory.sh/some-test-path", +// Identity: i, +// }, +// }, +// { +// desc: "legacy filepath without scheme", +// template: "./stub/test_body.jsonnet", +// data: &templateContext{ +// Flow: f, +// RequestHeaders: http.Header{ +// "Cookie": []string{"c1=v1", "c2=v2"}, +// "Some-Header": []string{"Some-Value"}, +// }, +// RequestMethod: "POST", +// RequestUrl: "https://test.kratos.ory.sh/some-test-path", +// Identity: i, +// }, +// }, +// { +// desc: "base64 encoded template URI", +// template: "base64://" + base64.StdEncoding.EncodeToString(testBodyJSONNet), +// data: &templateContext{ +// Flow: f, +// RequestHeaders: http.Header{ +// "Cookie": []string{"foo=bar"}, +// "My-Custom-Header": []string{"Cumstom-Value"}, +// }, +// RequestMethod: "PUT", +// RequestUrl: "https://test.kratos.ory.sh/other-test-path", +// Identity: i, +// }, +// }, +// } { +// t.Run("case="+tc.desc, func(t *testing.T) { +// b, err := readBody(l, tc.template, tc.data) +// require.NoError(t, err) +// body, err := io.ReadAll(b) +// require.NoError(t, err) +// +// expected, err := json.Marshal(map[string]interface{}{ +// "flow_id": tc.data.Flow.GetID(), +// "identity_id": tc.data.Identity.ID, +// "headers": tc.data.RequestHeaders, +// "method": tc.data.RequestMethod, +// "url": tc.data.RequestUrl, +// }) +// require.NoError(t, err) +// +// assert.JSONEq(t, string(expected), string(body)) +// }) +// } +// +// t.Run("case=warns about legacy usage", func(t *testing.T) { +// hook := test.Hook{} +// l := logrusx.New("kratos", "test", logrusx.WithHook(&hook)) +// +// _, _ = readBody(l, "./foo", nil) +// +// require.Len(t, hook.Entries, 1) +// assert.Contains(t, hook.LastEntry().Message, "support for filepaths without a 'file://' scheme will be dropped") +// }) +// +// t.Run("case=return non nil body reader on empty templateURI", func(t *testing.T) { +// body, err := readBody(l, "", nil) +// assert.NotNil(t, body) +// assert.Nil(t, err) +// }) +//} diff --git a/request/config.go b/request/config.go new file mode 100644 index 000000000000..bc6804ba1afc --- /dev/null +++ b/request/config.go @@ -0,0 +1,29 @@ +package request + +import ( + "encoding/json" +) + +type ( + Auth struct { + Type string + Config json.RawMessage + } + + Config struct { + Method string `json:"method"` + Url string `json:"url"` + TemplateURI string `json:"body"` + Auth Auth `json:"auth,omitempty"` + } +) + +func parseConfig(r json.RawMessage) (*Config, error) { + var c Config + err := json.Unmarshal(r, &c) + if err != nil { + return nil, err + } + + return &c, nil +} diff --git a/request/config_test.go b/request/config_test.go new file mode 100644 index 000000000000..bc21c96710d2 --- /dev/null +++ b/request/config_test.go @@ -0,0 +1,124 @@ +package request + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseConfig(t *testing.T) { + for _, tc := range []struct { + name string + rawConfig string + expectedConfig *Config + }{ + { + name: "no_auth", + expectedConfig: &Config{ + Method: http.MethodPost, + Url: "https://test.kratos.ory.sh/my_hook1", + TemplateURI: "/path/to/my/jsonnet1.file", + }, + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_hook1", + "method": "POST", + "body": "/path/to/my/jsonnet1.file" + }`, + }, + { + name: "basic_auth", + expectedConfig: &Config{ + Method: http.MethodGet, + Url: "https://test.kratos.ory.sh/my_hook2", + TemplateURI: "/path/to/my/jsonnet2.file", + Auth: Auth{ + Type: "basic_auth", + Config: json.RawMessage(`{ + "user": "test-api-user", + "password": "secret" + }`), + }, + }, + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_hook2", + "method": "GET", + "body": "/path/to/my/jsonnet2.file", + "auth": { + "type": "basic_auth", + "config": { + "user": "test-api-user", + "password": "secret" + } + } + }`, + }, + { + name: "api-key/header", + expectedConfig: &Config{ + Method: http.MethodDelete, + Url: "https://test.kratos.ory.sh/my_hook3", + TemplateURI: "/path/to/my/jsonnet3.file", + Auth: Auth{ + Type: "api_key", + Config: json.RawMessage(`{ + "in": "header", + "name": "my-api-key", + "value": "secret" + }`), + }, + }, + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_hook3", + "method": "DELETE", + "body": "/path/to/my/jsonnet3.file", + "auth": { + "type": "api_key", + "config": { + "in": "header", + "name": "my-api-key", + "value": "secret" + } + } + }`, + }, + { + name: "api-key/cookie", + expectedConfig: &Config{ + Method: http.MethodPost, + Url: "https://test.kratos.ory.sh/my_hook4", + TemplateURI: "/path/to/my/jsonnet4.file", + Auth: Auth{ + Type: "api_key", + Config: json.RawMessage(`{ + "in": "cookie", + "name": "my-api-key", + "value": "secret" + }`), + }, + }, + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_hook4", + "method": "POST", + "body": "/path/to/my/jsonnet4.file", + "auth": { + "type": "api_key", + "config": { + "in": "cookie", + "name": "my-api-key", + "value": "secret" + } + } + }`, + }, + } { + t.Run("auth-strategy="+tc.name, func(t *testing.T) { + conf, err := parseConfig([]byte(tc.rawConfig)) + require.NoError(t, err) + + assert.Equal(t, tc.expectedConfig, conf) + }) + } +} diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index b3b5da1b5093..6291c37dfbd5 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -1,19 +1,12 @@ package hook import ( - "bytes" "encoding/json" "fmt" - "io" "net/http" - "github.com/ory/x/fetcher" - "github.com/ory/x/logrusx" - - "github.com/google/go-jsonnet" - "github.com/pkg/errors" - "github.com/ory/kratos/identity" + "github.com/ory/kratos/request" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/recovery" @@ -29,32 +22,6 @@ var _ verification.PostHookExecutor = new(WebHook) var _ recovery.PostHookExecutor = new(WebHook) type ( - AuthStrategy interface { - apply(req *http.Request) - } - - authStrategyFactory func(c json.RawMessage) (AuthStrategy, error) - - noopAuthStrategy struct{} - - basicAuthStrategy struct { - user string - password string - } - - apiKeyStrategy struct { - name string - value string - in string - } - - webHookConfig struct { - method string - url string - templateURI string - auth AuthStrategy - } - webHookDependencies interface { x.LoggingProvider } @@ -68,113 +35,13 @@ type ( } WebHook struct { - r webHookDependencies - c json.RawMessage + deps webHookDependencies + conf json.RawMessage } ) -var strategyFactories = map[string]authStrategyFactory{ - "": newNoopAuthStrategy, - "api_key": newApiKeyStrategy, - "basic_auth": newBasicAuthStrategy, -} - -func newAuthStrategy(name string, c json.RawMessage) (as AuthStrategy, err error) { - if f, ok := strategyFactories[name]; ok { - as, err = f(c) - } else { - err = fmt.Errorf("unsupported auth type: %s", name) - } - return -} - -func newNoopAuthStrategy(_ json.RawMessage) (AuthStrategy, error) { - return &noopAuthStrategy{}, nil -} - -func (c *noopAuthStrategy) apply(_ *http.Request) {} - -func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) { - type config struct { - User string - Password string - } - - var c config - if err := json.Unmarshal(raw, &c); err != nil { - return nil, err - } - - return &basicAuthStrategy{ - user: c.User, - password: c.Password, - }, nil -} - -func (c *basicAuthStrategy) apply(req *http.Request) { - req.SetBasicAuth(c.user, c.password) -} - -func newApiKeyStrategy(raw json.RawMessage) (AuthStrategy, error) { - type config struct { - In string - Name string - Value string - } - - var c config - if err := json.Unmarshal(raw, &c); err != nil { - return nil, err - } - - return &apiKeyStrategy{ - in: c.In, - name: c.Name, - value: c.Value, - }, nil -} - -func (c *apiKeyStrategy) apply(req *http.Request) { - switch c.in { - case "cookie": - req.AddCookie(&http.Cookie{Name: c.name, Value: c.value}) - default: - req.Header.Set(c.name, c.value) - } -} - -func newWebHookConfig(r json.RawMessage) (*webHookConfig, error) { - type rawWebHookConfig struct { - Method string - Url string - Body string - Auth struct { - Type string - Config json.RawMessage - } - } - - var rc rawWebHookConfig - err := json.Unmarshal(r, &rc) - if err != nil { - return nil, err - } - - as, err := newAuthStrategy(rc.Auth.Type, rc.Auth.Config) - if err != nil { - return nil, fmt.Errorf("failed to create web hook auth strategy: %w", err) - } - - return &webHookConfig{ - method: rc.Method, - url: rc.Url, - templateURI: rc.Body, - auth: as, - }, nil -} - func NewWebHook(r webHookDependencies, c json.RawMessage) *WebHook { - return &WebHook{r: r, c: c} + return &WebHook{deps: r, conf: c} } func (e *WebHook) ExecuteLoginPreHook(_ http.ResponseWriter, req *http.Request, flow *login.Flow) error { @@ -246,84 +113,21 @@ func (e *WebHook) ExecuteSettingsPostPersistHook(_ http.ResponseWriter, req *htt } func (e *WebHook) execute(data *templateContext) error { - // TODO: reminder for the future: move parsing of config to the web hook initialization - conf, err := newWebHookConfig(e.c) - if err != nil { - return fmt.Errorf("failed to parse web hook config: %w", err) - } - - var body io.Reader - if conf.method != "TRACE" { - // According to the HTTP spec any request method, but TRACE is allowed to - // have a body. Even this is a really bad practice for some of them, like for - // GET - body, err = createBody(e.r.Logger(), conf.templateURI, data) - if err != nil { - return fmt.Errorf("failed to create web hook body: %w", err) - } - } - - if body == nil { - body = bytes.NewReader(make([]byte, 0)) - } - if err = doHttpCall(conf.method, conf.url, conf.auth, body); err != nil { - return fmt.Errorf("failed to call web hook %w", err) - } - return nil -} - -func createBody(l *logrusx.Logger, templateURI string, data *templateContext) (*bytes.Reader, error) { - if len(templateURI) == 0 { - return bytes.NewReader(make([]byte, 0)), nil - } - - f := fetcher.NewFetcher() - - template, err := f.Fetch(templateURI) - if errors.Is(err, fetcher.ErrUnknownScheme) { - // legacy filepath - templateURI = "file://" + templateURI - l.WithError(err).Warnf("support for filepaths without a 'file://' scheme will be dropped in the next release, please use %s instead in your config", templateURI) - template, err = f.Fetch(templateURI) - } - // this handles the first error if it is a known scheme error, or the second fetch error + builder, err := request.NewBuilder(e.conf, e.deps.Logger()) if err != nil { - return nil, err - } - - vm := jsonnet.MakeVM() - - buf := new(bytes.Buffer) - enc := json.NewEncoder(buf) - enc.SetEscapeHTML(false) - enc.SetIndent("", "") - - if err := enc.Encode(data); err != nil { - return nil, err + return err } - vm.TLACode("ctx", buf.String()) - if res, err := vm.EvaluateAnonymousSnippet(templateURI, template.String()); err != nil { - return nil, err - } else { - return bytes.NewReader([]byte(res)), nil - } -} - -func doHttpCall(method string, url string, as AuthStrategy, body io.Reader) error { - req, err := http.NewRequest(method, url, body) + req, err := builder.BuildRequest(data) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - - as.apply(req) resp, err := http.DefaultClient.Do(req) - if err != nil { return err - } else if resp.StatusCode >= 400 { + } + if resp.StatusCode >= http.StatusBadRequest { return fmt.Errorf("web hook failed with status code %v", resp.StatusCode) } diff --git a/selfservice/hook/web_hook_test.go b/selfservice/hook/web_hook_test.go index 27b3166fdda6..9ba8691929d9 100644 --- a/selfservice/hook/web_hook_test.go +++ b/selfservice/hook/web_hook_test.go @@ -5,346 +5,35 @@ import ( "encoding/base64" "encoding/json" "fmt" - "io" "io/ioutil" "net/http" "net/http/httptest" "strconv" "testing" - "github.com/sirupsen/logrus/hooks/test" - - "github.com/ory/x/logrusx" + "github.com/julienschmidt/httprouter" + "github.com/stretchr/testify/assert" + "github.com/ory/kratos/identity" + "github.com/ory/kratos/selfservice/flow" + "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/recovery" "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/selfservice/flow/settings" "github.com/ory/kratos/selfservice/flow/verification" - - "github.com/ory/kratos/selfservice/flow" - - "github.com/julienschmidt/httprouter" - - "github.com/ory/kratos/identity" - "github.com/ory/kratos/x" - "github.com/ory/kratos/session" - - "github.com/stretchr/testify/require" - - "github.com/ory/kratos/selfservice/flow/login" - - "github.com/stretchr/testify/assert" + "github.com/ory/kratos/x" + "github.com/ory/x/logrusx" ) -func TestNoopAuthStrategy(t *testing.T) { - req := http.Request{Header: map[string][]string{}} - auth := noopAuthStrategy{} - - auth.apply(&req) - - assert.Empty(t, req.Header, "Empty auth strategy shall not modify any request headers") -} - -func TestBasicAuthStrategy(t *testing.T) { - req := http.Request{Header: map[string][]string{}} - auth := basicAuthStrategy{ - user: "test-user", - password: "test-pass", - } - - auth.apply(&req) - - assert.Len(t, req.Header, 1) - - user, pass, _ := req.BasicAuth() - assert.Equal(t, "test-user", user) - assert.Equal(t, "test-pass", pass) -} - -func TestApiKeyInHeaderStrategy(t *testing.T) { - req := http.Request{Header: map[string][]string{}} - auth := apiKeyStrategy{ - in: "header", - name: "my-api-key-name", - value: "my-api-key-value", - } - - auth.apply(&req) - - require.Len(t, req.Header, 1) - - actualValue := req.Header.Get("my-api-key-name") - assert.Equal(t, "my-api-key-value", actualValue) -} - -func TestApiKeyInCookieStrategy(t *testing.T) { - req := http.Request{Header: map[string][]string{}} - auth := apiKeyStrategy{ - in: "cookie", - name: "my-api-key-name", - value: "my-api-key-value", - } - - auth.apply(&req) - - cookies := req.Cookies() - assert.Len(t, cookies, 1) - - assert.Equal(t, "my-api-key-name", cookies[0].Name) - assert.Equal(t, "my-api-key-value", cookies[0].Value) -} - -//go:embed stub/test_body.jsonnet -var testBodyJSONNet []byte - -func TestJsonNetSupport(t *testing.T) { - f := &login.Flow{ID: x.NewUUID()} - i := identity.NewIdentity("") - l := logrusx.New("kratos", "test") - - for _, tc := range []struct { - desc, template string - data *templateContext - }{ - { - desc: "simple file URI", - template: "file://./stub/test_body.jsonnet", - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"c1=v1", "c2=v2"}, - "Some-Header": []string{"Some-Value"}, - }, - RequestMethod: "POST", - RequestUrl: "https://test.kratos.ory.sh/some-test-path", - Identity: i, - }, - }, - { - desc: "legacy filepath without scheme", - template: "./stub/test_body.jsonnet", - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"c1=v1", "c2=v2"}, - "Some-Header": []string{"Some-Value"}, - }, - RequestMethod: "POST", - RequestUrl: "https://test.kratos.ory.sh/some-test-path", - Identity: i, - }, - }, - { - desc: "base64 encoded template URI", - template: "base64://" + base64.StdEncoding.EncodeToString(testBodyJSONNet), - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"foo=bar"}, - "My-Custom-Header": []string{"Cumstom-Value"}, - }, - RequestMethod: "PUT", - RequestUrl: "https://test.kratos.ory.sh/other-test-path", - Identity: i, - }, - }, - } { - t.Run("case="+tc.desc, func(t *testing.T) { - b, err := createBody(l, tc.template, tc.data) - require.NoError(t, err) - body, err := io.ReadAll(b) - require.NoError(t, err) - - expected, err := json.Marshal(map[string]interface{}{ - "flow_id": tc.data.Flow.GetID(), - "identity_id": tc.data.Identity.ID, - "headers": tc.data.RequestHeaders, - "method": tc.data.RequestMethod, - "url": tc.data.RequestUrl, - }) - require.NoError(t, err) - - assert.JSONEq(t, string(expected), string(body)) - }) - } - - t.Run("case=warns about legacy usage", func(t *testing.T) { - hook := test.Hook{} - l := logrusx.New("kratos", "test", logrusx.WithHook(&hook)) - - _, _ = createBody(l, "./foo", nil) - - require.Len(t, hook.Entries, 1) - assert.Contains(t, hook.LastEntry().Message, "support for filepaths without a 'file://' scheme will be dropped") - }) - - t.Run("case=return non nil body reader on empty templateURI", func(t *testing.T) { - body, err := createBody(l, "", nil) - assert.NotNil(t, body) - assert.Nil(t, err) - }) -} - -func TestWebHookConfig(t *testing.T) { - for _, tc := range []struct { - strategy string - method string - url string - body string - rawConfig string - authStrategy AuthStrategy - }{ - { - strategy: "empty", - method: "POST", - url: "https://test.kratos.ory.sh/my_hook1", - body: "/path/to/my/jsonnet1.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook1", - "method": "POST", - "body": "/path/to/my/jsonnet1.file" - }`, - authStrategy: &noopAuthStrategy{}, - }, - { - strategy: "basic_auth", - method: "GET", - url: "https://test.kratos.ory.sh/my_hook2", - body: "/path/to/my/jsonnet2.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook2", - "method": "GET", - "body": "/path/to/my/jsonnet2.file", - "auth": { - "type": "basic_auth", - "config": { - "user": "test-api-user", - "password": "secret" - } - } - }`, - authStrategy: &basicAuthStrategy{}, - }, - { - strategy: "api-key/header", - method: "DELETE", - url: "https://test.kratos.ory.sh/my_hook3", - body: "/path/to/my/jsonnet3.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook3", - "method": "DELETE", - "body": "/path/to/my/jsonnet3.file", - "auth": { - "type": "api_key", - "config": { - "in": "header", - "name": "my-api-key", - "value": "secret" - } - } - }`, - authStrategy: &apiKeyStrategy{}, - }, - { - strategy: "api-key/cookie", - method: "POST", - url: "https://test.kratos.ory.sh/my_hook4", - body: "/path/to/my/jsonnet4.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook4", - "method": "POST", - "body": "/path/to/my/jsonnet4.file", - "auth": { - "type": "api_key", - "config": { - "in": "cookie", - "name": "my-api-key", - "value": "secret" - } - } - }`, - authStrategy: &apiKeyStrategy{}, - }, - } { - t.Run("auth-strategy="+tc.strategy, func(t *testing.T) { - conf, err := newWebHookConfig([]byte(tc.rawConfig)) - assert.Nil(t, err) - - assert.Equal(t, tc.url, conf.url) - assert.Equal(t, tc.method, conf.method) - assert.Equal(t, tc.body, conf.templateURI) - assert.NotNil(t, conf.auth) - assert.IsTypef(t, tc.authStrategy, conf.auth, "Auth should be of the expected type") - }) - } +type webHookRequest struct { + Body string + Headers http.Header + Method string } func TestWebHooks(t *testing.T) { - type WebHookRequest struct { - Body string - Headers http.Header - Method string - } - - webHookEndPoint := func(whr *WebHookRequest) httprouter.Handle { - return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - body, err := ioutil.ReadAll(r.Body) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - whr.Body = string(body) - whr.Headers = r.Header - whr.Method = r.Method - } - } - - webHookHttpCodeEndPoint := func(code int) httprouter.Handle { - return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { - w.WriteHeader(code) - } - } - - path := "/web_hook" - newServer := func(f httprouter.Handle) *httptest.Server { - r := httprouter.New() - - r.Handle("CONNECT", path, f) - r.DELETE(path, f) - r.GET(path, f) - r.OPTIONS(path, f) - r.PATCH(path, f) - r.POST(path, f) - r.PUT(path, f) - r.Handle("TRACE", path, f) - - ts := httptest.NewServer(r) - t.Cleanup(ts.Close) - return ts - } - - bodyWithFlowOnly := func(req *http.Request, f flow.Flow) string { - h, _ := json.Marshal(req.Header) - return fmt.Sprintf(`{ - "flow_id": "%s", - "identity_id": null, - "headers": %s, - "method": "%s", - "url": "%s" - }`, f.GetID(), string(h), req.Method, req.RequestURI) - } - - bodyWithFlowAndIdentity := func(req *http.Request, f flow.Flow, s *session.Session) string { - h, _ := json.Marshal(req.Header) - return fmt.Sprintf(`{ - "flow_id": "%s", - "identity_id": "%s", - "headers": %s, - "method": "%s", - "url": "%s" - }`, f.GetID(), s.Identity.ID, string(h), req.Method, req.RequestURI) - } + const path = "/web_hook" for _, tc := range []struct { uc string @@ -485,15 +174,19 @@ func TestWebHooks(t *testing.T) { t.Run("auth="+auth.uc, func(t *testing.T) { for _, method := range []string{"CONNECT", "DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT", "TRACE", "GARBAGE"} { t.Run("method="+method, func(t *testing.T) { - f := tc.createFlow() req := &http.Request{ Header: map[string][]string{"Some-Header": {"Some-Value"}}, RequestURI: "https://www.ory.sh/some_end_point", Method: http.MethodPost, } + + f := tc.createFlow() + s := &session.Session{ID: x.NewUUID(), Identity: &identity.Identity{ID: x.NewUUID()}} - whr := &WebHookRequest{} - ts := newServer(webHookEndPoint(whr)) + + whr := &webHookRequest{} + ts := newServer(t, path, webHookEndPoint(whr)) + conf := json.RawMessage(fmt.Sprintf(`{ "url": "%s", "method": "%s", @@ -521,11 +214,11 @@ func TestWebHooks(t *testing.T) { assert.Equal(t, v, vals) } - if method != "TRACE" { + if method != http.MethodTrace { // According to the HTTP spec any request method, but TRACE is allowed to - // have a body. Even this is a really bad practice for some of them, like for - // GET - assert.JSONEq(t, tc.expectedBody(req, f, s), whr.Body) + // have a body. Even this is a bad practice for some of them, like for GET + expectedBody := tc.expectedBody(req, f, s) + assert.JSONEq(t, expectedBody, whr.Body) } else { assert.Emptyf(t, whr.Body, "HTTP %s is not allowed to have a body", method) } @@ -551,7 +244,7 @@ func TestWebHooks(t *testing.T) { }) t.Run("Must error when template is erroneous", func(t *testing.T) { - ts := newServer(webHookHttpCodeEndPoint(200)) + ts := newServer(t, path, webHookHttpCodeEndPoint(http.StatusOK)) req := &http.Request{ Header: map[string][]string{"Some-Header": {"Some-Value"}}, RequestURI: "https://www.ory.sh/some_end_point", @@ -591,7 +284,7 @@ func TestWebHooks(t *testing.T) { {599, false}, } { t.Run("Must"+boolToString(tc.mustSuccess)+" error when end point is returning "+strconv.Itoa(tc.code), func(t *testing.T) { - ts := newServer(webHookHttpCodeEndPoint(tc.code)) + ts := newServer(t, path, webHookHttpCodeEndPoint(tc.code)) req := &http.Request{ Header: map[string][]string{"Some-Header": {"Some-Value"}}, RequestURI: "https://www.ory.sh/some_end_point", @@ -614,3 +307,67 @@ func TestWebHooks(t *testing.T) { }) } } + +func newServer(t *testing.T, path string, f httprouter.Handle) *httptest.Server { + r := httprouter.New() + + r.Handle("CONNECT", path, f) + r.DELETE(path, f) + r.GET(path, f) + r.OPTIONS(path, f) + r.PATCH(path, f) + r.POST(path, f) + r.PUT(path, f) + r.Handle("TRACE", path, f) + + ts := httptest.NewServer(r) + t.Cleanup(ts.Close) + return ts +} + +func webHookHttpCodeEndPoint(code int) httprouter.Handle { + return func(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { + w.WriteHeader(code) + } +} + +func webHookEndPoint(whr *webHookRequest) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + whr.Body = string(body) + whr.Headers = r.Header + whr.Method = r.Method + } +} + +func bodyWithFlowOnly(req *http.Request, f flow.Flow) string { + h, _ := json.Marshal(req.Header) + + const config = `{ + "flow_id": "%s", + "identity_id": null, + "headers": %s, + "method": "%s", + "url": "%s" + }` + + return fmt.Sprintf(config, f.GetID(), string(h), req.Method, req.RequestURI) +} + +func bodyWithFlowAndIdentity(req *http.Request, f flow.Flow, s *session.Session) string { + h, _ := json.Marshal(req.Header) + + const configFmt = `{ + "flow_id": "%s", + "identity_id": "%s", + "headers": %s, + "method": "%s", + "url": "%s" + }` + + return fmt.Sprintf(configFmt, f.GetID(), s.Identity.ID, string(h), req.Method, req.RequestURI) +} diff --git a/selfservice/strategy/link/sender.go b/selfservice/strategy/link/sender.go index 77aaee20c518..80e91123f800 100644 --- a/selfservice/strategy/link/sender.go +++ b/selfservice/strategy/link/sender.go @@ -5,6 +5,8 @@ import ( "net/http" "net/url" + "github.com/ory/kratos/courier/template/email" + "github.com/pkg/errors" "github.com/ory/x/errorsx" @@ -12,7 +14,6 @@ import ( "github.com/ory/x/urlx" "github.com/ory/kratos/courier" - templates "github.com/ory/kratos/courier/template" "github.com/ory/kratos/driver/config" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/recovery" @@ -59,7 +60,7 @@ func (s *Sender) SendRecoveryLink(ctx context.Context, r *http.Request, f *recov address, err := s.r.IdentityPool().FindRecoveryAddressByValue(ctx, identity.RecoveryAddressTypeEmail, to) if err != nil { - if err := s.send(ctx, string(via), templates.NewRecoveryInvalid(s.r.Config(ctx), &templates.RecoveryInvalidModel{To: to})); err != nil { + if err := s.send(ctx, string(via), email.NewRecoveryInvalid(s.r.Config(ctx), &email.RecoveryInvalidModel{To: to})); err != nil { return err } return errors.Cause(ErrUnknownAddress) @@ -99,7 +100,7 @@ func (s *Sender) SendVerificationLink(ctx context.Context, f *verification.Flow, WithField("via", via). WithSensitiveField("email_address", address). Info("Sending out invalid verification email because address is unknown.") - if err := s.send(ctx, string(via), templates.NewVerificationInvalid(s.r.Config(ctx), &templates.VerificationInvalidModel{To: to})); err != nil { + if err := s.send(ctx, string(via), email.NewVerificationInvalid(s.r.Config(ctx), &email.VerificationInvalidModel{To: to})); err != nil { return err } return errors.Cause(ErrUnknownAddress) @@ -138,8 +139,8 @@ func (s *Sender) SendRecoveryTokenTo(ctx context.Context, f *recovery.Flow, i *i return err } - return s.send(ctx, string(address.Via), templates.NewRecoveryValid(s.r.Config(ctx), - &templates.RecoveryValidModel{To: address.Value, RecoveryURL: urlx.CopyWithQuery( + return s.send(ctx, string(address.Via), email.NewRecoveryValid(s.r.Config(ctx), + &email.RecoveryValidModel{To: address.Value, RecoveryURL: urlx.CopyWithQuery( urlx.AppendPaths(s.r.Config(ctx).SelfServiceLinkMethodBaseURL(), recovery.RouteSubmitFlow), url.Values{ "token": {token.Token}, @@ -161,8 +162,8 @@ func (s *Sender) SendVerificationTokenTo(ctx context.Context, f *verification.Fl return err } - if err := s.send(ctx, string(address.Via), templates.NewVerificationValid(s.r.Config(ctx), - &templates.VerificationValidModel{To: address.Value, VerificationURL: urlx.CopyWithQuery( + if err := s.send(ctx, string(address.Via), email.NewVerificationValid(s.r.Config(ctx), + &email.VerificationValidModel{To: address.Value, VerificationURL: urlx.CopyWithQuery( urlx.AppendPaths(s.r.Config(ctx).SelfServiceLinkMethodBaseURL(), verification.RouteSubmitFlow), url.Values{ "flow": {f.ID.String()}, diff --git a/x/require.go b/x/require.go index 17154b1593f6..689dc6a1aab5 100644 --- a/x/require.go +++ b/x/require.go @@ -5,6 +5,7 @@ import ( "encoding/json" "testing" + "github.com/gofrs/uuid" "github.com/stretchr/testify/require" ) @@ -13,3 +14,7 @@ func RequireJSONMarshal(t *testing.T, in interface{}) []byte { require.NoError(t, json.NewEncoder(&b).Encode(in)) return b.Bytes() } + +func RequireNotNilUUID(t *testing.T, id uuid.UUID) { + require.NotEqual(t, uuid.Nil, id) +}