diff --git a/courier/courier.go b/courier/courier.go index 329a95a48d27..d5253038229a 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -2,144 +2,68 @@ package courier import ( "context" - "crypto/tls" - "encoding/json" - "fmt" - "strconv" "time" - "github.com/hashicorp/go-retryablehttp" - - "github.com/ory/kratos/driver/config" - "github.com/ory/x/httpx" - "github.com/cenkalti/backoff" "github.com/gofrs/uuid" + "github.com/hashicorp/go-retryablehttp" "github.com/pkg/errors" - "github.com/ory/herodot" - - gomail "github.com/ory/mail/v3" - + "github.com/ory/kratos/driver/config" "github.com/ory/kratos/x" + gomail "github.com/ory/mail/v3" + "github.com/ory/x/httpx" ) type ( - SMTPDependencies interface { + Dependencies interface { PersistenceProvider x.LoggingProvider ConfigProvider HTTPClient(ctx context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client } - TemplateTyper func(t EmailTemplate) (TemplateType, error) - EmailTemplateFromMessage func(d SMTPDependencies, msg Message) (EmailTemplate, error) - Courier struct { - Dialer *gomail.Dialer - d SMTPDependencies - GetTemplateType TemplateTyper - NewEmailTemplateFromMessage EmailTemplateFromMessage + + Courier interface { + Work(ctx context.Context) error + QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) + QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error) + SmtpDialer() *gomail.Dialer + DispatchQueue(ctx context.Context) error } + Provider interface { - Courier(ctx context.Context) *Courier + Courier(ctx context.Context) Courier } + ConfigProvider interface { CourierConfig(ctx context.Context) config.CourierConfigs } -) - -func NewSMTP(ctx context.Context, d SMTPDependencies) *Courier { - uri := d.CourierConfig(ctx).CourierSMTPURL() - password, _ := uri.User.Password() - port, _ := strconv.ParseInt(uri.Port(), 10, 0) - - dialer := &gomail.Dialer{ - Host: uri.Hostname(), - Port: int(port), - Username: uri.User.Username(), - Password: password, - - Timeout: time.Second * 10, - RetryFailure: true, - } - - sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) - - // SMTP schemes - // smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default) - // smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks - // and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314) - switch uri.Scheme { - case "smtp": - // Enforcing StartTLS by default for security best practices (config review, etc.) - skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls")) - if !skipStartTLS { - // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. - dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} - // Enforcing StartTLS - dialer.StartTLSPolicy = gomail.MandatoryStartTLS - } - case "smtps": - // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. - dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} - dialer.SSL = true + courier struct { + smsClient *smsClient + smtpClient *smtpClient + deps Dependencies + failOnError bool } +) - return &Courier{ - d: d, - Dialer: dialer, - GetTemplateType: GetTemplateType, - NewEmailTemplateFromMessage: NewEmailTemplateFromMessage, +func NewCourier(ctx context.Context, deps Dependencies) Courier { + return &courier{ + smsClient: newSMS(ctx, deps), + smtpClient: newSMTP(ctx, deps), + deps: deps, } } -func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { - recipient, err := t.EmailRecipient() - if err != nil { - return uuid.Nil, err - } - - subject, err := t.EmailSubject(ctx) - if err != nil { - return uuid.Nil, err - } - - bodyPlaintext, err := t.EmailBodyPlaintext(ctx) - if err != nil { - return uuid.Nil, err - } - - templateType, err := m.GetTemplateType(t) - if err != nil { - return uuid.Nil, err - } - - templateData, err := json.Marshal(t) - if err != nil { - return uuid.Nil, err - } - - message := &Message{ - Status: MessageStatusQueued, - Type: MessageTypeEmail, - Recipient: recipient, - Body: bodyPlaintext, - Subject: subject, - TemplateType: templateType, - TemplateData: templateData, - } - - if err := m.d.CourierPersister().AddMessage(ctx, message); err != nil { - return uuid.Nil, err - } - return message.ID, nil +func (c *courier) FailOnDispatchError() { + c.failOnError = true } -func (m *Courier) Work(ctx context.Context) error { +func (c *courier) Work(ctx context.Context) error { errChan := make(chan error) defer close(errChan) - go m.watchMessages(ctx, errChan) + go c.watchMessages(ctx, errChan) select { case <-ctx.Done(): @@ -152,10 +76,10 @@ func (m *Courier) Work(ctx context.Context) error { } } -func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { +func (c *courier) watchMessages(ctx context.Context, errChan chan error) { for { if err := backoff.Retry(func() error { - return m.DispatchQueue(ctx) + return c.DispatchQueue(ctx) }, backoff.NewExponentialBackOff()); err != nil { errChan <- err return @@ -163,105 +87,3 @@ func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { time.Sleep(time.Second) } } - -func (m *Courier) DispatchMessage(ctx context.Context, msg Message) error { - switch msg.Type { - case MessageTypeEmail: - from := m.d.CourierConfig(ctx).CourierSMTPFrom() - fromName := m.d.CourierConfig(ctx).CourierSMTPFromName() - gm := gomail.NewMessage() - if fromName == "" { - gm.SetHeader("From", from) - } else { - gm.SetAddressHeader("From", from, fromName) - } - - gm.SetHeader("To", msg.Recipient) - gm.SetHeader("Subject", msg.Subject) - - headers := m.d.CourierConfig(ctx).CourierSMTPHeaders() - for k, v := range headers { - gm.SetHeader(k, v) - } - - gm.SetBody("text/plain", msg.Body) - - tmpl, err := m.NewEmailTemplateFromMessage(m.d, msg) - if err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to get email template from message.`) - } else { - htmlBody, err := tmpl.EmailBody(ctx) - if err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to get email body from template.`) - } else { - gm.AddAlternative("text/html", htmlBody) - } - } - - if err := m.Dialer.DialAndSend(ctx, gm); err != nil { - m.d.Logger(). - WithError(err). - WithField("smtp_server", fmt.Sprintf("%s:%d", m.Dialer.Host, m.Dialer.Port)). - WithField("smtp_ssl_enabled", m.Dialer.SSL). - // WithField("email_to", msg.Recipient). - WithField("message_from", from). - Error("Unable to send email using SMTP connection.") - return errors.WithStack(err) - } - - if err := m.d.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", msg.ID). - Error(`Unable to set the message status to "sent".`) - return err - } - - m.d.Logger(). - WithField("message_id", msg.ID). - WithField("message_type", msg.Type). - WithField("message_template_type", msg.TemplateType). - WithField("message_subject", msg.Subject). - Debug("Courier sent out message.") - return nil - } - return errors.Errorf("received unexpected message type: %d", msg.Type) -} - -func (m *Courier) DispatchQueue(ctx context.Context) error { - if len(m.Dialer.Host) == 0 { - return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but courier.smtp_url is not set!")) - } - - messages, err := m.d.CourierPersister().NextMessages(ctx, 10) - if err != nil { - if errors.Is(err, ErrQueueEmpty) { - return nil - } - return err - } - - for k := range messages { - var msg = messages[k] - if err := m.DispatchMessage(ctx, msg); err != nil { - for _, replace := range messages[k:] { - if err := m.d.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { - m.d.Logger(). - WithError(err). - WithField("message_id", replace.ID). - Error(`Unable to reset the failed message's status to "queued".`) - } - } - - return err - } - } - - return nil -} diff --git a/courier/courier_dispatcher.go b/courier/courier_dispatcher.go new file mode 100644 index 000000000000..6810fbd50cc2 --- /dev/null +++ b/courier/courier_dispatcher.go @@ -0,0 +1,70 @@ +package courier + +import ( + "context" + + "github.com/pkg/errors" +) + +func (c *courier) DispatchMessage(ctx context.Context, msg Message) error { + switch msg.Type { + case MessageTypeEmail: + if err := c.dispatchEmail(ctx, msg); err != nil { + return err + } + case MessageTypePhone: + if err := c.dispatchSMS(ctx, msg); err != nil { + return err + } + default: + return errors.Errorf("received unexpected message type: %d", msg.Type) + } + + if err := c.deps.CourierPersister().SetMessageStatus(ctx, msg.ID, MessageStatusSent); err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to set the message status to "sent".`) + return err + } + + c.deps.Logger(). + WithField("message_id", msg.ID). + WithField("message_type", msg.Type). + WithField("message_template_type", msg.TemplateType). + WithField("message_subject", msg.Subject). + Debug("Courier sent out message.") + + return nil +} + +func (c *courier) DispatchQueue(ctx context.Context) error { + messages, err := c.deps.CourierPersister().NextMessages(ctx, 10) + if err != nil { + if errors.Is(err, ErrQueueEmpty) { + return nil + } + return err + } + + for k := range messages { + var msg = messages[k] + if err := c.DispatchMessage(ctx, msg); err != nil { + for _, replace := range messages[k:] { + if err := c.deps.CourierPersister().SetMessageStatus(ctx, replace.ID, MessageStatusQueued); err != nil { + if c.failOnError { + return err + } + c.deps.Logger(). + WithError(err). + WithField("message_id", replace.ID). + Error(`Unable to reset the failed message's status to "queued".`) + } + } + + return err + } + } + + return nil +} diff --git a/courier/courier_test.go b/courier/courier_test.go index d47ba8e7fd41..b871e19fbe54 100644 --- a/courier/courier_test.go +++ b/courier/courier_test.go @@ -1,30 +1,10 @@ package courier_test import ( - "context" - "fmt" - "io/ioutil" - "net/http" "testing" - "time" - - "github.com/sirupsen/logrus" "github.com/ory/kratos/x" - gomail "github.com/ory/mail/v3" - - "github.com/gofrs/uuid" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" - dhelper "github.com/ory/x/sqlcon/dockertest" - - courier "github.com/ory/kratos/courier" - templates "github.com/ory/kratos/courier/template" - "github.com/ory/kratos/driver/config" - "github.com/ory/kratos/internal" ) // nolint:staticcheck @@ -33,132 +13,3 @@ func TestMain(m *testing.M) { atexit.Add(x.CleanUpTestSMTP) atexit.Exit(m.Run()) } - -func TestNewSMTP(t *testing.T) { - ctx := context.Background() - - setupConfig := func(stringURL string) *courier.Courier { - conf, reg := internal.NewFastRegistryWithMocks(t) - conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL) - t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String()) - return courier.NewSMTP(ctx, reg) - } - - if testing.Short() { - t.SkipNow() - } - - //Should enforce StartTLS => dialer.StartTLSPolicy = gomail.MandatoryStartTLS and dialer.SSL = false - smtp := setupConfig("smtp://foo:bar@my-server:1234/") - assert.Equal(t, smtp.Dialer.StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced") - assert.Equal(t, smtp.Dialer.SSL, false, "Implicit TLS should not be enabled") - - //Should enforce TLS => dialer.SSL = true - smtp = setupConfig("smtps://foo:bar@my-server:1234/") - assert.Equal(t, smtp.Dialer.SSL, true, "Implicit TLS should be enabled") - - //Should allow cleartext => dialer.StartTLSPolicy = gomail.OpportunisticStartTLS and dialer.SSL = false - smtp = setupConfig("smtp://foo:bar@my-server:1234/?disable_starttls=true") - assert.Equal(t, smtp.Dialer.StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced") - assert.Equal(t, smtp.Dialer.SSL, false, "Implicit TLS should not be enabled") -} - -func TestSMTP(t *testing.T) { - if testing.Short() { - t.SkipNow() - } - - smtp, api, err := x.RunTestSMTP() - require.NoError(t, err) - t.Logf("SMTP URL: %s", smtp) - t.Logf("API URL: %s", api) - - ctx := context.Background() - - conf, reg := internal.NewFastRegistryWithMocks(t) - conf.MustSet(config.ViperKeyCourierSMTPURL, smtp) - conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh") - reg.Logger().Level = logrus.TraceLevel - - c := reg.Courier(ctx) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - id, err := c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ - To: "test-recipient-1@example.org", - Subject: "test-subject-1", - Body: "test-body-1", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ - To: "test-recipient-2@example.org", - Subject: "test-subject-2", - Body: "test-body-2", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - // The third email contains a sender name and custom headers - conf.MustSet(config.ViperKeyCourierSMTPFromName, "Bob") - conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header1", "foo") - conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header2", "bar") - customerHeaders := conf.CourierSMTPHeaders() - require.Len(t, customerHeaders, 2) - id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ - To: "test-recipient-3@example.org", - Subject: "test-subject-3", - Body: "test-body-3", - })) - require.NoError(t, err) - require.NotEqual(t, uuid.Nil, id) - - go func() { - require.NoError(t, c.Work(ctx)) - }() - - var body []byte - for k := 0; k < 30; k++ { - time.Sleep(time.Second) - err = func() error { - res, err := http.Get(api + "/api/v2/messages") - if err != nil { - return err - } - - defer res.Body.Close() - body, err = ioutil.ReadAll(res.Body) - if err != nil { - return err - } - - if http.StatusOK != res.StatusCode { - return errors.Errorf("expected status code 200 but got %d with body: %s", res.StatusCode, body) - } - - if total := gjson.GetBytes(body, "total").Int(); total != 3 { - return errors.Errorf("expected to have delivered at least 3 messages but got count %d with body: %s", total, body) - } - - return nil - }() - if err == nil { - break - } - } - require.NoError(t, err) - - for k := 1; k <= 3; k++ { - assert.Contains(t, string(body), fmt.Sprintf("test-subject-%d", k)) - assert.Contains(t, string(body), fmt.Sprintf("test-body-%d", k)) - assert.Contains(t, string(body), fmt.Sprintf("test-recipient-%d@example.org", k)) - assert.Contains(t, string(body), "test-stub@ory.sh") - } - - // Assertion for the third email with sender name and headers - assert.Contains(t, string(body), "Bob") - assert.Contains(t, string(body), `"test-stub-header1":["foo"]`) - assert.Contains(t, string(body), `"test-stub-header2":["bar"]`) -} diff --git a/courier/templates.go b/courier/email_templates.go similarity index 65% rename from courier/templates.go rename to courier/email_templates.go index cbadd5731062..2ea3ea3bdf30 100644 --- a/courier/templates.go +++ b/courier/email_templates.go @@ -6,11 +6,12 @@ import ( "github.com/pkg/errors" - "github.com/ory/kratos/courier/template" + "github.com/ory/kratos/courier/template/email" ) type ( - TemplateType string + TemplateType string + EmailTemplate interface { json.Marshaler EmailSubject(context.Context) (string, error) @@ -25,58 +26,59 @@ const ( TypeRecoveryValid TemplateType = "recovery_valid" TypeVerificationInvalid TemplateType = "verification_invalid" TypeVerificationValid TemplateType = "verification_valid" + TypeOTP TemplateType = "otp" TypeTestStub TemplateType = "stub" ) -func GetTemplateType(t EmailTemplate) (TemplateType, error) { +func GetEmailTemplateType(t EmailTemplate) (TemplateType, error) { switch t.(type) { - case *template.RecoveryInvalid: + case *email.RecoveryInvalid: return TypeRecoveryInvalid, nil - case *template.RecoveryValid: + case *email.RecoveryValid: return TypeRecoveryValid, nil - case *template.VerificationInvalid: + case *email.VerificationInvalid: return TypeVerificationInvalid, nil - case *template.VerificationValid: + case *email.VerificationValid: return TypeVerificationValid, nil - case *template.TestStub: + case *email.TestStub: return TypeTestStub, nil default: return "", errors.Errorf("unexpected template type") } } -func NewEmailTemplateFromMessage(d SMTPDependencies, msg Message) (EmailTemplate, error) { +func NewEmailTemplateFromMessage(d Dependencies, msg Message) (EmailTemplate, error) { switch msg.TemplateType { case TypeRecoveryInvalid: - var t template.RecoveryInvalidModel + var t email.RecoveryInvalidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewRecoveryInvalid(d, &t), nil + return email.NewRecoveryInvalid(d, &t), nil case TypeRecoveryValid: - var t template.RecoveryValidModel + var t email.RecoveryValidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewRecoveryValid(d, &t), nil + return email.NewRecoveryValid(d, &t), nil case TypeVerificationInvalid: - var t template.VerificationInvalidModel + var t email.VerificationInvalidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewVerificationInvalid(d, &t), nil + return email.NewVerificationInvalid(d, &t), nil case TypeVerificationValid: - var t template.VerificationValidModel + var t email.VerificationValidModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewVerificationValid(d, &t), nil + return email.NewVerificationValid(d, &t), nil case TypeTestStub: - var t template.TestStubModel + var t email.TestStubModel if err := json.Unmarshal(msg.TemplateData, &t); err != nil { return nil, err } - return template.NewTestStub(d, &t), nil + return email.NewTestStub(d, &t), nil default: return nil, errors.Errorf("received unexpected message template type: %s", msg.TemplateType) } diff --git a/courier/templates_test.go b/courier/email_templates_test.go similarity index 64% rename from courier/templates_test.go rename to courier/email_templates_test.go index e41fa0705bdc..e6b97885e36d 100644 --- a/courier/templates_test.go +++ b/courier/email_templates_test.go @@ -9,25 +9,23 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/kratos/courier" - "github.com/ory/kratos/courier/template" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/internal" ) func TestGetTemplateType(t *testing.T) { for expectedType, tmpl := range map[courier.TemplateType]courier.EmailTemplate{ - courier.TypeRecoveryInvalid: &template.RecoveryInvalid{}, - courier.TypeRecoveryValid: &template.RecoveryValid{}, - courier.TypeVerificationInvalid: &template.VerificationInvalid{}, - courier.TypeVerificationValid: &template.VerificationValid{}, - courier.TypeTestStub: &template.TestStub{}, + courier.TypeRecoveryInvalid: &email.RecoveryInvalid{}, + courier.TypeRecoveryValid: &email.RecoveryValid{}, + courier.TypeVerificationInvalid: &email.VerificationInvalid{}, + courier.TypeVerificationValid: &email.VerificationValid{}, + courier.TypeTestStub: &email.TestStub{}, } { t.Run(fmt.Sprintf("case=%s", expectedType), func(t *testing.T) { - actualType, err := courier.GetTemplateType(tmpl) + actualType, err := courier.GetEmailTemplateType(tmpl) require.NoError(t, err) require.Equal(t, expectedType, actualType) - }) - } } @@ -36,11 +34,11 @@ func TestNewEmailTemplateFromMessage(t *testing.T) { ctx := context.Background() for tmplType, expectedTmpl := range map[courier.TemplateType]courier.EmailTemplate{ - courier.TypeRecoveryInvalid: template.NewRecoveryInvalid(reg, &template.RecoveryInvalidModel{To: "foo"}), - courier.TypeRecoveryValid: template.NewRecoveryValid(reg, &template.RecoveryValidModel{To: "bar", RecoveryURL: "http://foo.bar"}), - courier.TypeVerificationInvalid: template.NewVerificationInvalid(reg, &template.VerificationInvalidModel{To: "baz"}), - courier.TypeVerificationValid: template.NewVerificationValid(reg, &template.VerificationValidModel{To: "faz", VerificationURL: "http://bar.foo"}), - courier.TypeTestStub: template.NewTestStub(reg, &template.TestStubModel{To: "far", Subject: "test subject", Body: "test body"}), + courier.TypeRecoveryInvalid: email.NewRecoveryInvalid(reg, &email.RecoveryInvalidModel{To: "foo"}), + courier.TypeRecoveryValid: email.NewRecoveryValid(reg, &email.RecoveryValidModel{To: "bar", RecoveryURL: "http://foo.bar"}), + courier.TypeVerificationInvalid: email.NewVerificationInvalid(reg, &email.VerificationInvalidModel{To: "baz"}), + courier.TypeVerificationValid: email.NewVerificationValid(reg, &email.VerificationValidModel{To: "faz", VerificationURL: "http://bar.foo"}), + courier.TypeTestStub: email.NewTestStub(reg, &email.TestStubModel{To: "far", Subject: "test subject", Body: "test body"}), } { t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) { tmplData, err := json.Marshal(expectedTmpl) diff --git a/courier/message.go b/courier/message.go index 94b102781cb2..0641a0d49b85 100644 --- a/courier/message.go +++ b/courier/message.go @@ -4,9 +4,9 @@ import ( "context" "time" - "github.com/ory/kratos/corp" - "github.com/gofrs/uuid" + + "github.com/ory/kratos/corp" ) type MessageStatus int @@ -21,6 +21,7 @@ type MessageType int const ( MessageTypeEmail MessageType = iota + 1 + MessageTypePhone ) // swagger:ignore diff --git a/courier/sms.go b/courier/sms.go new file mode 100644 index 000000000000..1436a6062abf --- /dev/null +++ b/courier/sms.go @@ -0,0 +1,113 @@ +package courier + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/pkg/errors" + + "github.com/ory/herodot" + + "github.com/gofrs/uuid" + + "github.com/ory/kratos/request" +) + +type sendSMSRequestBody struct { + From string `json:"from"` + To string `json:"to"` + Body string `json:"body"` +} + +type smsClient struct { + RequestConfig json.RawMessage + + GetTemplateType func(t SMSTemplate) (TemplateType, error) + NewTemplateFromMessage func(d Dependencies, msg Message) (SMSTemplate, error) +} + +func newSMS(ctx context.Context, deps Dependencies) *smsClient { + return &smsClient{ + RequestConfig: deps.CourierConfig(ctx).CourierSMSRequestConfig(), + + GetTemplateType: SMSTemplateType, + NewTemplateFromMessage: NewSMSTemplateFromMessage, + } +} + +func (c *courier) QueueSMS(ctx context.Context, t SMSTemplate) (uuid.UUID, error) { + recipient, err := t.PhoneNumber() + if err != nil { + return uuid.Nil, err + } + + templateType, err := c.smsClient.GetTemplateType(t) + if err != nil { + return uuid.Nil, err + } + + templateData, err := json.Marshal(t) + if err != nil { + return uuid.Nil, err + } + + message := &Message{ + Status: MessageStatusQueued, + Type: MessageTypePhone, + Recipient: recipient, + TemplateType: templateType, + TemplateData: templateData, + } + if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil { + return uuid.Nil, err + } + + return message.ID, nil +} + +func (c *courier) dispatchSMS(ctx context.Context, msg Message) error { + if !c.deps.CourierConfig(ctx).CourierSMSEnabled() { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an sms but courier.sms.enabled is set to false!")) + } + + tmpl, err := c.smsClient.NewTemplateFromMessage(c.deps, msg) + if err != nil { + return err + } + + body, err := tmpl.SMSBody(ctx) + if err != nil { + return err + } + + builder, err := request.NewBuilder(c.smsClient.RequestConfig, c.deps.HTTPClient(ctx), c.deps.Logger()) + if err != nil { + return err + } + + req, err := builder.BuildRequest(&sendSMSRequestBody{ + To: msg.Recipient, + From: c.deps.CourierConfig(ctx).CourierSMSFrom(), + Body: body, + }) + if err != nil { + return err + } + + res, err := c.deps.HTTPClient(ctx).Do(req) + if err != nil { + return err + } + + defer res.Body.Close() + + switch res.StatusCode { + case http.StatusOK: + case http.StatusCreated: + default: + return errors.New(http.StatusText(res.StatusCode)) + } + + return nil +} diff --git a/courier/sms_templates.go b/courier/sms_templates.go new file mode 100644 index 000000000000..079268bd8e1a --- /dev/null +++ b/courier/sms_templates.go @@ -0,0 +1,46 @@ +package courier + +import ( + "context" + "encoding/json" + + "github.com/pkg/errors" + + "github.com/ory/kratos/courier/template/sms" +) + +type SMSTemplate interface { + json.Marshaler + SMSBody(context.Context) (string, error) + PhoneNumber() (string, error) +} + +func SMSTemplateType(t SMSTemplate) (TemplateType, error) { + switch t.(type) { + case *sms.OTPMessage: + return TypeOTP, nil + case *sms.TestStub: + return TypeTestStub, nil + default: + return "", errors.Errorf("unexpected template type") + } +} + +func NewSMSTemplateFromMessage(d Dependencies, m Message) (SMSTemplate, error) { + switch m.TemplateType { + case TypeOTP: + var t sms.OTPMessageModel + if err := json.Unmarshal(m.TemplateData, &t); err != nil { + return nil, err + } + return sms.NewOTPMessage(d, &t), nil + case TypeTestStub: + var t sms.TestStubModel + if err := json.Unmarshal(m.TemplateData, &t); err != nil { + return nil, err + } + return sms.NewTestStub(d, &t), nil + default: + return nil, errors.Errorf("received unexpected message template type: %s", m.TemplateType) + } +} diff --git a/courier/sms_templates_test.go b/courier/sms_templates_test.go new file mode 100644 index 000000000000..760f89a21e04 --- /dev/null +++ b/courier/sms_templates_test.go @@ -0,0 +1,60 @@ +package courier_test + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestSMSTemplateType(t *testing.T) { + for expectedType, tmpl := range map[courier.TemplateType]courier.SMSTemplate{ + courier.TypeOTP: &sms.OTPMessage{}, + courier.TypeTestStub: &sms.TestStub{}, + } { + t.Run(fmt.Sprintf("case=%s", expectedType), func(t *testing.T) { + actualType, err := courier.SMSTemplateType(tmpl) + require.NoError(t, err) + require.Equal(t, expectedType, actualType) + }) + } +} + +func TestNewSMSTemplateFromMessage(t *testing.T) { + _, reg := internal.NewFastRegistryWithMocks(t) + ctx := context.Background() + + for tmplType, expectedTmpl := range map[courier.TemplateType]courier.SMSTemplate{ + courier.TypeOTP: sms.NewOTPMessage(reg, &sms.OTPMessageModel{To: "+12345678901"}), + courier.TypeTestStub: sms.NewTestStub(reg, &sms.TestStubModel{To: "+12345678901", Body: "test body"}), + } { + t.Run(fmt.Sprintf("case=%s", tmplType), func(t *testing.T) { + tmplData, err := json.Marshal(expectedTmpl) + require.NoError(t, err) + + m := courier.Message{TemplateType: tmplType, TemplateData: tmplData} + actualTmpl, err := courier.NewSMSTemplateFromMessage(reg, m) + require.NoError(t, err) + + require.IsType(t, expectedTmpl, actualTmpl) + + expectedRecipient, err := expectedTmpl.PhoneNumber() + require.NoError(t, err) + actualRecipient, err := actualTmpl.PhoneNumber() + require.NoError(t, err) + require.Equal(t, expectedRecipient, actualRecipient) + + expectedBody, err := expectedTmpl.SMSBody(ctx) + require.NoError(t, err) + actualBody, err := actualTmpl.SMSBody(ctx) + require.NoError(t, err) + require.Equal(t, expectedBody, actualBody) + }) + } +} diff --git a/courier/sms_test.go b/courier/sms_test.go new file mode 100644 index 000000000000..fdcd234f4ce3 --- /dev/null +++ b/courier/sms_test.go @@ -0,0 +1,144 @@ +package courier_test + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/internal" + "github.com/ory/x/resilience" +) + +func TestQueueSMS(t *testing.T) { + expectedSender := "Kratos Test" + expectedSMS := []*sms.TestStubModel{ + { + To: "+12065550101", + Body: "test-sms-body-1", + }, + { + To: "+12065550102", + Body: "test-sms-body-2", + }, + } + + actual := make([]*sms.TestStubModel, 0, 2) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + type sendSMSRequestBody struct { + To string + From string + Body string + } + + rb, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + + var body sendSMSRequestBody + + err = json.Unmarshal(rb, &body) + require.NoError(t, err) + + assert.NotEmpty(t, r.Header["Authorization"]) + assert.Equal(t, "Basic bWU6MTIzNDU=", r.Header["Authorization"][0]) + + assert.Equal(t, body.From, expectedSender) + actual = append(actual, &sms.TestStubModel{ + To: body.To, + Body: body.Body, + }) + })) + + requestConfig := fmt.Sprintf(`{ + "url": "%s", + "method": "POST", + "body": "file://./stub/request.config.twilio.jsonnet", + "auth": { + "type": "basic_auth", + "config": { + "user": "me", + "password": "12345" + } + } + }`, srv.URL) + + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(config.ViperKeyCourierSMSRequestConfig, requestConfig) + conf.MustSet(config.ViperKeyCourierSMSFrom, expectedSender) + conf.MustSet(config.ViperKeyCourierSMSEnabled, true) + conf.MustSet(config.ViperKeyCourierSMTPURL, "http://foo.url") + reg.Logger().Level = logrus.TraceLevel + + ctx := context.Background() + + c := reg.Courier(ctx) + + ctx, cancel := context.WithCancel(ctx) + defer t.Cleanup(cancel) + + for _, message := range expectedSMS { + id, err := c.QueueSMS(ctx, sms.NewTestStub(reg, message)) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, id) + } + + go func() { + require.NoError(t, c.Work(ctx)) + }() + + require.NoError(t, resilience.Retry(reg.Logger(), time.Millisecond*250, time.Second*10, func() error { + if len(actual) == len(expectedSMS) { + return nil + } + return errors.New("capacity not reached") + })) + + for i, message := range actual { + expected := expectedSMS[i] + + assert.Equal(t, expected.To, message.To) + assert.Equal(t, fmt.Sprintf("stub sms body %s\n", expected.Body), message.Body) + } + + srv.Close() +} + +func TestDisallowedInternalNetwork(t *testing.T) { + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(config.ViperKeyCourierSMSRequestConfig, fmt.Sprintf(`{ + "url": "http://127.0.0.1/", + "method": "GET", + "body": "file://./stub/request.config.twilio.jsonnet" + }`)) + conf.MustSet(config.ViperKeyCourierSMSEnabled, true) + conf.MustSet(config.ViperKeyCourierSMTPURL, "http://foo.url") + conf.MustSet(config.ViperKeyClientHTTPNoPrivateIPRanges, true) + reg.Logger().Level = logrus.TraceLevel + + ctx := context.Background() + c := reg.Courier(ctx) + c.(interface { + FailOnDispatchError() + }).FailOnDispatchError() + _, err := c.QueueSMS(ctx, sms.NewTestStub(reg, &sms.TestStubModel{ + To: "+12065550101", + Body: "test-sms-body-1", + })) + require.NoError(t, err) + + err = c.DispatchQueue(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "ip 127.0.0.1 is in the 127.0.0.0/8 range") +} diff --git a/courier/smtp.go b/courier/smtp.go new file mode 100644 index 000000000000..8d5e19f83be2 --- /dev/null +++ b/courier/smtp.go @@ -0,0 +1,182 @@ +package courier + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/ory/kratos/driver/config" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/ory/herodot" + gomail "github.com/ory/mail/v3" +) + +type smtpClient struct { + *gomail.Dialer + + GetTemplateType func(t EmailTemplate) (TemplateType, error) + NewTemplateFromMessage func(d Dependencies, msg Message) (EmailTemplate, error) +} + +func newSMTP(ctx context.Context, deps Dependencies) *smtpClient { + uri := deps.CourierConfig(ctx).CourierSMTPURL() + + password, _ := uri.User.Password() + port, _ := strconv.ParseInt(uri.Port(), 10, 0) + + dialer := &gomail.Dialer{ + Host: uri.Hostname(), + Port: int(port), + Username: uri.User.Username(), + Password: password, + + Timeout: time.Second * 10, + RetryFailure: true, + } + + sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) + + // SMTP schemes + // smtp: smtp clear text (with uri parameter) or with StartTLS (enforced by default) + // smtps: smtp with implicit TLS (recommended way in 2021 to avoid StartTLS downgrade attacks + // and defaulting to fully-encrypted protocols https://datatracker.ietf.org/doc/html/rfc8314) + switch uri.Scheme { + case "smtp": + // Enforcing StartTLS by default for security best practices (config review, etc.) + skipStartTLS, _ := strconv.ParseBool(uri.Query().Get("disable_starttls")) + if !skipStartTLS { + // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. + dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} + // Enforcing StartTLS + dialer.StartTLSPolicy = gomail.MandatoryStartTLS + } + case "smtps": + // #nosec G402 This is ok (and required!) because it is configurable and disabled by default. + dialer.TLSConfig = &tls.Config{InsecureSkipVerify: sslSkipVerify, ServerName: uri.Hostname()} + dialer.SSL = true + } + + return &smtpClient{ + Dialer: dialer, + + GetTemplateType: GetEmailTemplateType, + NewTemplateFromMessage: NewEmailTemplateFromMessage, + } +} + +func (c *courier) SmtpDialer() *gomail.Dialer { + return c.smtpClient.Dialer +} + +func (c *courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, error) { + recipient, err := t.EmailRecipient() + if err != nil { + return uuid.Nil, err + } + + subject, err := t.EmailSubject(ctx) + if err != nil { + return uuid.Nil, err + } + + bodyPlaintext, err := t.EmailBodyPlaintext(ctx) + if err != nil { + return uuid.Nil, err + } + + templateType, err := c.smtpClient.GetTemplateType(t) + if err != nil { + return uuid.Nil, err + } + + templateData, err := json.Marshal(t) + if err != nil { + return uuid.Nil, err + } + + message := &Message{ + Status: MessageStatusQueued, + Type: MessageTypeEmail, + Recipient: recipient, + Body: bodyPlaintext, + Subject: subject, + TemplateType: templateType, + TemplateData: templateData, + } + + if err := c.deps.CourierPersister().AddMessage(ctx, message); err != nil { + return uuid.Nil, err + } + + return message.ID, nil +} + +func (c *courier) dispatchEmail(ctx context.Context, msg Message) error { + if c.smtpClient.Host == "" { + return errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Courier tried to deliver an email but %s is not set!", config.ViperKeyCourierSMTPURL)) + } + + from := c.deps.CourierConfig(ctx).CourierSMTPFrom() + fromName := c.deps.CourierConfig(ctx).CourierSMTPFromName() + + gm := gomail.NewMessage() + if fromName == "" { + gm.SetHeader("From", from) + } else { + gm.SetAddressHeader("From", from, fromName) + } + + gm.SetHeader("To", msg.Recipient) + gm.SetHeader("Subject", msg.Subject) + + headers := c.deps.CourierConfig(ctx).CourierSMTPHeaders() + for k, v := range headers { + gm.SetHeader(k, v) + } + + gm.SetBody("text/plain", msg.Body) + + tmpl, err := c.smtpClient.NewTemplateFromMessage(c.deps, msg) + if err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to get email template from message.`) + } else { + htmlBody, err := tmpl.EmailBody(ctx) + if err != nil { + c.deps.Logger(). + WithError(err). + WithField("message_id", msg.ID). + Error(`Unable to get email body from template.`) + } else { + gm.AddAlternative("text/html", htmlBody) + } + } + + if err := c.smtpClient.DialAndSend(ctx, gm); err != nil { + c.deps.Logger(). + WithError(err). + WithField("smtp_server", fmt.Sprintf("%s:%d", c.smtpClient.Host, c.smtpClient.Port)). + WithField("smtp_ssl_enabled", c.smtpClient.SSL). + // WithField("email_to", msg.Recipient). + WithField("message_from", from). + Error("Unable to send email using SMTP connection.") + return errors.WithStack(err) + } + + c.deps.Logger(). + WithField("message_id", msg.ID). + WithField("message_type", msg.Type). + WithField("message_template_type", msg.TemplateType). + WithField("message_subject", msg.Subject). + Debug("Courier sent out message.") + + return nil +} diff --git a/courier/smtp_test.go b/courier/smtp_test.go new file mode 100644 index 000000000000..626d2c45d175 --- /dev/null +++ b/courier/smtp_test.go @@ -0,0 +1,155 @@ +package courier_test + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + + "github.com/ory/kratos/courier" + templates "github.com/ory/kratos/courier/template/email" + "github.com/ory/kratos/driver/config" + "github.com/ory/kratos/internal" + "github.com/ory/kratos/x" + gomail "github.com/ory/mail/v3" +) + +func TestNewSMTP(t *testing.T) { + ctx := context.Background() + + setupConfig := func(stringURL string) courier.Courier { + conf, reg := internal.NewFastRegistryWithMocks(t) + conf.MustSet(config.ViperKeyCourierSMTPURL, stringURL) + + t.Logf("SMTP URL: %s", conf.CourierSMTPURL().String()) + + return courier.NewCourier(ctx, reg) + } + + if testing.Short() { + t.SkipNow() + } + + //Should enforce StartTLS => dialer.StartTLSPolicy = gomail.MandatoryStartTLS and dialer.SSL = false + smtp := setupConfig("smtp://foo:bar@my-server:1234/") + assert.Equal(t, smtp.SmtpDialer().StartTLSPolicy, gomail.MandatoryStartTLS, "StartTLS not enforced") + assert.Equal(t, smtp.SmtpDialer().SSL, false, "Implicit TLS should not be enabled") + + //Should enforce TLS => dialer.SSL = true + smtp = setupConfig("smtps://foo:bar@my-server:1234/") + assert.Equal(t, smtp.SmtpDialer().SSL, true, "Implicit TLS should be enabled") + + //Should allow cleartext => dialer.StartTLSPolicy = gomail.OpportunisticStartTLS and dialer.SSL = false + smtp = setupConfig("smtp://foo:bar@my-server:1234/?disable_starttls=true") + assert.Equal(t, smtp.SmtpDialer().StartTLSPolicy, gomail.OpportunisticStartTLS, "StartTLS is enforced") + assert.Equal(t, smtp.SmtpDialer().SSL, false, "Implicit TLS should not be enabled") +} + +func TestQueueEmail(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + + smtp, api, err := x.RunTestSMTP() + require.NoError(t, err) + t.Logf("SMTP URL: %s", smtp) + t.Logf("API URL: %s", api) + + ctx := context.Background() + + conf, reg := internal.NewRegistryDefaultWithDSN(t, "") + conf.MustSet(config.ViperKeyCourierSMTPURL, smtp) + conf.MustSet(config.ViperKeyCourierSMTPFrom, "test-stub@ory.sh") + reg.Logger().Level = logrus.TraceLevel + + c := reg.Courier(ctx) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + id, err := c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ + To: "test-recipient-1@example.org", + Subject: "test-subject-1", + Body: "test-body-1", + })) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, id) + + id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ + To: "test-recipient-2@example.org", + Subject: "test-subject-2", + Body: "test-body-2", + })) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, id) + + // The third email contains a sender name and custom headers + conf.MustSet(config.ViperKeyCourierSMTPFromName, "Bob") + conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header1", "foo") + conf.MustSet(config.ViperKeyCourierSMTPHeaders+".test-stub-header2", "bar") + customerHeaders := conf.CourierSMTPHeaders() + require.Len(t, customerHeaders, 2) + id, err = c.QueueEmail(ctx, templates.NewTestStub(reg, &templates.TestStubModel{ + To: "test-recipient-3@example.org", + Subject: "test-subject-3", + Body: "test-body-3", + })) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, id) + + go func() { + require.NoError(t, c.Work(ctx)) + }() + + var body []byte + for k := 0; k < 30; k++ { + time.Sleep(time.Second) + err = func() error { + res, err := http.Get(api + "/api/v2/messages") + if err != nil { + return err + } + + defer res.Body.Close() + body, err = ioutil.ReadAll(res.Body) + if err != nil { + return err + } + + if http.StatusOK != res.StatusCode { + return errors.Errorf("expected status code 200 but got %d with body: %s", res.StatusCode, body) + } + + if total := gjson.GetBytes(body, "total").Int(); total != 3 { + return errors.Errorf("expected to have delivered at least 3 messages but got count %d with body: %s", total, body) + } + + return nil + }() + if err == nil { + break + } + } + require.NoError(t, err) + + for k := 1; k <= 3; k++ { + assert.Contains(t, string(body), fmt.Sprintf("test-subject-%d", k)) + assert.Contains(t, string(body), fmt.Sprintf("test-body-%d", k)) + assert.Contains(t, string(body), fmt.Sprintf("test-recipient-%d@example.org", k)) + assert.Contains(t, string(body), "test-stub@ory.sh") + } + + // Assertion for the third email with sender name and headers + assert.Contains(t, string(body), "Bob") + assert.Contains(t, string(body), `"test-stub-header1":["foo"]`) + assert.Contains(t, string(body), `"test-stub-header2":["bar"]`) +} diff --git a/courier/stub/request.config.twilio.jsonnet b/courier/stub/request.config.twilio.jsonnet new file mode 100644 index 000000000000..da0736b06df0 --- /dev/null +++ b/courier/stub/request.config.twilio.jsonnet @@ -0,0 +1,5 @@ +function(ctx) { + from: ctx.from, + to: ctx.to, + body: ctx.body +} diff --git a/courier/template/courier/builtin/templates/otp/sms.body.gotmpl b/courier/template/courier/builtin/templates/otp/sms.body.gotmpl new file mode 100644 index 000000000000..ff95187e9e7f --- /dev/null +++ b/courier/template/courier/builtin/templates/otp/sms.body.gotmpl @@ -0,0 +1 @@ +Your verification code is: {{ .Code }} diff --git a/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl b/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl new file mode 100644 index 000000000000..a37e4640152d --- /dev/null +++ b/courier/template/courier/builtin/templates/otp/test_stub/sms.body.gotmpl @@ -0,0 +1 @@ +stub sms body {{ .Body }} diff --git a/courier/template/email/recovery_invalid.go b/courier/template/email/recovery_invalid.go new file mode 100644 index 000000000000..25c20c2095ef --- /dev/null +++ b/courier/template/email/recovery_invalid.go @@ -0,0 +1,43 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + RecoveryInvalid struct { + d template.Dependencies + m *RecoveryInvalidModel + } + RecoveryInvalidModel struct { + To string + } +) + +func NewRecoveryInvalid(d template.Dependencies, m *RecoveryInvalidModel) *RecoveryInvalid { + return &RecoveryInvalid{d: d, m: m} +} + +func (t *RecoveryInvalid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *RecoveryInvalid) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Subject) +} + +func (t *RecoveryInvalid) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Body.HTML) +} + +func (t *RecoveryInvalid) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Body.PlainText) +} + +func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/recovery_invalid_test.go b/courier/template/email/recovery_invalid_test.go similarity index 66% rename from courier/template/recovery_invalid_test.go rename to courier/template/email/recovery_invalid_test.go index aa5b1c83f0c2..d3d533ab53af 100644 --- a/courier/template/recovery_invalid_test.go +++ b/courier/template/email/recovery_invalid_test.go @@ -1,13 +1,12 @@ -package template_test +package email_test import ( "context" "testing" "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/courier/template/testhelpers" - - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) @@ -17,12 +16,12 @@ func TestRecoverInvalid(t *testing.T) { t.Run("test=with courier templates directory", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - tpl := template.NewRecoveryInvalid(reg, &template.RecoveryInvalidModel{}) + tpl := email.NewRecoveryInvalid(reg, &email.RecoveryInvalidModel{}) testhelpers.TestRendered(t, ctx, tpl) }) t.Run("case=test remote resources", func(t *testing.T) { - testhelpers.TestRemoteTemplates(t, "courier/builtin/templates/recovery/invalid", courier.TypeRecoveryInvalid) + testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery/invalid", courier.TypeRecoveryInvalid) }) } diff --git a/courier/template/email/recovery_valid.go b/courier/template/email/recovery_valid.go new file mode 100644 index 000000000000..65ce00f27c0e --- /dev/null +++ b/courier/template/email/recovery_valid.go @@ -0,0 +1,45 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + RecoveryValid struct { + d template.Dependencies + m *RecoveryValidModel + } + RecoveryValidModel struct { + To string + RecoveryURL string + Identity map[string]interface{} + } +) + +func NewRecoveryValid(d template.Dependencies, m *RecoveryValidModel) *RecoveryValid { + return &RecoveryValid{d: d, m: m} +} + +func (t *RecoveryValid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *RecoveryValid) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Subject) +} + +func (t *RecoveryValid) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Body.HTML) +} + +func (t *RecoveryValid) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Body.PlainText) +} + +func (t *RecoveryValid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/recovery_valid_test.go b/courier/template/email/recovery_valid_test.go similarity index 67% rename from courier/template/recovery_valid_test.go rename to courier/template/email/recovery_valid_test.go index 0de24aa4b9dc..0264fba9a4df 100644 --- a/courier/template/recovery_valid_test.go +++ b/courier/template/email/recovery_valid_test.go @@ -1,13 +1,12 @@ -package template_test +package email_test import ( "context" "testing" "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/courier/template/testhelpers" - - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) @@ -17,12 +16,12 @@ func TestRecoverValid(t *testing.T) { t.Run("test=with courier templates directory", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - tpl := template.NewRecoveryValid(reg, &template.RecoveryValidModel{}) + tpl := email.NewRecoveryValid(reg, &email.RecoveryValidModel{}) testhelpers.TestRendered(t, ctx, tpl) }) t.Run("test=with remote resources", func(t *testing.T) { - testhelpers.TestRemoteTemplates(t, "courier/builtin/templates/recovery/valid", courier.TypeRecoveryValid) + testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/recovery/valid", courier.TypeRecoveryValid) }) } diff --git a/courier/template/email/stub.go b/courier/template/email/stub.go new file mode 100644 index 000000000000..e5cecaf657a8 --- /dev/null +++ b/courier/template/email/stub.go @@ -0,0 +1,45 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + TestStub struct { + d template.Dependencies + m *TestStubModel + } + TestStubModel struct { + To string + Subject string + Body string + } +) + +func NewTestStub(d template.Dependencies, m *TestStubModel) *TestStub { + return &TestStub{d: d, m: m} +} + +func (t *TestStub) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *TestStub) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m, "") +} + +func (t *TestStub) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m, "") +} + +func (t *TestStub) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m, "") +} + +func (t *TestStub) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/email/verification_invalid.go b/courier/template/email/verification_invalid.go new file mode 100644 index 000000000000..f153c13aa922 --- /dev/null +++ b/courier/template/email/verification_invalid.go @@ -0,0 +1,43 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + VerificationInvalid struct { + d template.Dependencies + m *VerificationInvalidModel + } + VerificationInvalidModel struct { + To string + } +) + +func NewVerificationInvalid(d template.Dependencies, m *VerificationInvalidModel) *VerificationInvalid { + return &VerificationInvalid{d: d, m: m} +} + +func (t *VerificationInvalid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *VerificationInvalid) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Subject) +} + +func (t *VerificationInvalid) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Body.HTML) +} + +func (t *VerificationInvalid) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Body.PlainText) +} + +func (t *VerificationInvalid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/verification_invalid_test.go b/courier/template/email/verification_invalid_test.go similarity index 67% rename from courier/template/verification_invalid_test.go rename to courier/template/email/verification_invalid_test.go index 8bbb9972e582..15a837e09968 100644 --- a/courier/template/verification_invalid_test.go +++ b/courier/template/email/verification_invalid_test.go @@ -1,13 +1,12 @@ -package template_test +package email_test import ( "context" "testing" "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/courier/template/testhelpers" - - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) @@ -17,14 +16,14 @@ func TestVerifyInvalid(t *testing.T) { t.Run("test=with courier templates directory", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - tpl := template.NewVerificationInvalid(reg, &template.VerificationInvalidModel{}) + tpl := email.NewVerificationInvalid(reg, &email.VerificationInvalidModel{}) testhelpers.TestRendered(t, ctx, tpl) }) t.Run("test=with remote resources", func(t *testing.T) { t.Run("test=with remote resources", func(t *testing.T) { - testhelpers.TestRemoteTemplates(t, "courier/builtin/templates/verification/invalid", courier.TypeVerificationInvalid) + testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification/invalid", courier.TypeVerificationInvalid) }) }) } diff --git a/courier/template/email/verification_valid.go b/courier/template/email/verification_valid.go new file mode 100644 index 000000000000..3de84840bdbe --- /dev/null +++ b/courier/template/email/verification_valid.go @@ -0,0 +1,45 @@ +package email + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + VerificationValid struct { + d template.Dependencies + m *VerificationValidModel + } + VerificationValidModel struct { + To string + VerificationURL string + Identity map[string]interface{} + } +) + +func NewVerificationValid(d template.Dependencies, m *VerificationValidModel) *VerificationValid { + return &VerificationValid{d: d, m: m} +} + +func (t *VerificationValid) EmailRecipient() (string, error) { + return t.m.To, nil +} + +func (t *VerificationValid) EmailSubject(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Subject) +} + +func (t *VerificationValid) EmailBody(ctx context.Context) (string, error) { + return template.LoadHTML(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Body.HTML) +} + +func (t *VerificationValid) EmailBodyPlaintext(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Body.PlainText) +} + +func (t *VerificationValid) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/verification_valid_test.go b/courier/template/email/verification_valid_test.go similarity index 65% rename from courier/template/verification_valid_test.go rename to courier/template/email/verification_valid_test.go index 2313c74d0fe2..1ce209445fea 100644 --- a/courier/template/verification_valid_test.go +++ b/courier/template/email/verification_valid_test.go @@ -1,13 +1,12 @@ -package template_test +package email_test import ( "context" "testing" "github.com/ory/kratos/courier" + "github.com/ory/kratos/courier/template/email" "github.com/ory/kratos/courier/template/testhelpers" - - "github.com/ory/kratos/courier/template" "github.com/ory/kratos/internal" ) @@ -17,12 +16,12 @@ func TestVerifyValid(t *testing.T) { t.Run("test=with courier templates directory", func(t *testing.T) { _, reg := internal.NewFastRegistryWithMocks(t) - tpl := template.NewVerificationValid(reg, &template.VerificationValidModel{}) + tpl := email.NewVerificationValid(reg, &email.VerificationValidModel{}) testhelpers.TestRendered(t, ctx, tpl) }) t.Run("test=with remote resources", func(t *testing.T) { - testhelpers.TestRemoteTemplates(t, "courier/builtin/templates/verification/valid", courier.TypeVerificationValid) + testhelpers.TestRemoteTemplates(t, "../courier/builtin/templates/verification/valid", courier.TypeVerificationValid) }) } diff --git a/courier/template/load_template.go b/courier/template/load_template.go index ceceb43e1960..d3e3194cbd38 100644 --- a/courier/template/load_template.go +++ b/courier/template/load_template.go @@ -149,7 +149,7 @@ func loadTemplate(filesystem fs.FS, name, pattern string, html bool) (Template, return tpl, nil } -func LoadTextTemplate(ctx context.Context, d templateDependencies, filesystem fs.FS, name, pattern string, model interface{}, remoteURL string) (string, error) { +func LoadText(ctx context.Context, d templateDependencies, filesystem fs.FS, name, pattern string, model interface{}, remoteURL string) (string, error) { var t Template var err error if remoteURL != "" { @@ -171,7 +171,7 @@ func LoadTextTemplate(ctx context.Context, d templateDependencies, filesystem fs return b.String(), nil } -func LoadHTMLTemplate(ctx context.Context, d templateDependencies, filesystem fs.FS, name, pattern string, model interface{}, remoteURL string) (string, error) { +func LoadHTML(ctx context.Context, d templateDependencies, filesystem fs.FS, name, pattern string, model interface{}, remoteURL string) (string, error) { var t Template var err error if remoteURL != "" { diff --git a/courier/template/load_template_test.go b/courier/template/load_template_test.go index b27ac0f21a01..89b9a936244d 100644 --- a/courier/template/load_template_test.go +++ b/courier/template/load_template_test.go @@ -30,7 +30,7 @@ func TestLoadTextTemplate(t *testing.T) { var executeTextTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string { ctx := context.Background() _, reg := internal.NewFastRegistryWithMocks(t) - tp, err := template.LoadTextTemplate(ctx, reg, os.DirFS(dir), name, pattern, model, "") + tp, err := template.LoadText(ctx, reg, os.DirFS(dir), name, pattern, model, "") require.NoError(t, err) return tp } @@ -38,7 +38,7 @@ func TestLoadTextTemplate(t *testing.T) { var executeHTMLTemplate = func(t *testing.T, dir, name, pattern string, model map[string]interface{}) string { ctx := context.Background() _, reg := internal.NewFastRegistryWithMocks(t) - tp, err := template.LoadHTMLTemplate(ctx, reg, os.DirFS(dir), name, pattern, model, "") + tp, err := template.LoadHTML(ctx, reg, os.DirFS(dir), name, pattern, model, "") require.NoError(t, err) return tp } @@ -70,7 +70,7 @@ func TestLoadTextTemplate(t *testing.T) { for _, tc := range nonhermetic { t.Run("case=should not support function: "+tc, func(t *testing.T) { - _, err := template.LoadTextTemplate(ctx, reg, x.NewStubFS(tc, []byte(fmt.Sprintf("{{ %s }}", tc))), tc, "", map[string]interface{}{}, "") + _, err := template.LoadText(ctx, reg, x.NewStubFS(tc, []byte(fmt.Sprintf("{{ %s }}", tc))), tc, "", map[string]interface{}{}, "") require.Error(t, err) require.Contains(t, err.Error(), fmt.Sprintf("function \"%s\" not defined", tc)) }) @@ -108,7 +108,7 @@ func TestLoadTextTemplate(t *testing.T) { f, err := ioutil.ReadFile("courier/builtin/templates/test_stub/email.body.html.en_US.gotmpl") require.NoError(t, err) b64 := base64.StdEncoding.EncodeToString(f) - tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m, "base64://"+b64) + tp, err := template.LoadHTML(ctx, reg, nil, "", "", m, "base64://"+b64) require.NoError(t, err) assert.Contains(t, tp, "lang=en_US") }) @@ -120,7 +120,7 @@ func TestLoadTextTemplate(t *testing.T) { b64 := base64.StdEncoding.EncodeToString(f) - tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m, "base64://"+b64) + tp, err := template.LoadText(ctx, reg, nil, "", "", m, "base64://"+b64) require.NoError(t, err) assert.Contains(t, tp, "stub email body something") }) @@ -130,14 +130,14 @@ func TestLoadTextTemplate(t *testing.T) { t.Run("case=file resource", func(t *testing.T) { t.Run("case=html template", func(t *testing.T) { m := map[string]interface{}{"lang": "en_US"} - tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m, "file://courier/builtin/templates/test_stub/email.body.html.en_US.gotmpl") + tp, err := template.LoadHTML(ctx, reg, nil, "", "", m, "file://courier/builtin/templates/test_stub/email.body.html.en_US.gotmpl") require.NoError(t, err) assert.Contains(t, tp, "lang=en_US") }) t.Run("case=plaintext", func(t *testing.T) { m := map[string]interface{}{"Body": "something"} - tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m, "file://courier/builtin/templates/test_stub/email.body.plaintext.gotmpl") + tp, err := template.LoadText(ctx, reg, nil, "", "", m, "file://courier/builtin/templates/test_stub/email.body.plaintext.gotmpl") require.NoError(t, err) assert.Contains(t, tp, "stub email body something") }) @@ -156,14 +156,14 @@ func TestLoadTextTemplate(t *testing.T) { t.Run("case=html template", func(t *testing.T) { m := map[string]interface{}{"lang": "en_US"} - tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", m, ts.URL+"/html") + tp, err := template.LoadHTML(ctx, reg, nil, "", "", m, ts.URL+"/html") require.NoError(t, err) assert.Contains(t, tp, "lang=en_US") }) t.Run("case=plaintext", func(t *testing.T) { m := map[string]interface{}{"Body": "something"} - tp, err := template.LoadTextTemplate(ctx, reg, nil, "", "", m, ts.URL+"/plaintext") + tp, err := template.LoadText(ctx, reg, nil, "", "", m, ts.URL+"/plaintext") require.NoError(t, err) assert.Contains(t, tp, "stub email body something") }) @@ -171,12 +171,12 @@ func TestLoadTextTemplate(t *testing.T) { }) t.Run("case=unsupported resource", func(t *testing.T) { - tp, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "grpc://unsupported-url") + tp, err := template.LoadHTML(ctx, reg, nil, "", "", map[string]interface{}{}, "grpc://unsupported-url") require.ErrorIs(t, err, fetcher.ErrUnknownScheme) require.Empty(t, tp) - tp, err = template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "grpc://unsupported-url") + tp, err = template.LoadText(ctx, reg, nil, "", "", map[string]interface{}{}, "grpc://unsupported-url") require.ErrorIs(t, err, fetcher.ErrUnknownScheme) require.Empty(t, tp) }) @@ -186,22 +186,22 @@ func TestLoadTextTemplate(t *testing.T) { reg.HTTPClient(ctx).RetryMax = 1 reg.HTTPClient(ctx).RetryWaitMax = time.Millisecond - _, err := template.LoadHTMLTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "http://localhost:8080/1234") + _, err := template.LoadHTML(ctx, reg, nil, "", "", map[string]interface{}{}, "http://localhost:8080/1234") require.Error(t, err) assert.Contains(t, err.Error(), "is in the") - _, err = template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "http://localhost:8080/1234") + _, err = template.LoadText(ctx, reg, nil, "", "", map[string]interface{}{}, "http://localhost:8080/1234") require.Error(t, err) assert.Contains(t, err.Error(), "is in the") }) t.Run("method=cache works", func(t *testing.T) { - tp1, err := template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "base64://e3sgJGwgOj0gY2F0ICJsYW5nPSIgLmxhbmcgfX0Ke3sgbm9zcGFjZSAkbCB9fQ==") + tp1, err := template.LoadText(ctx, reg, nil, "", "", map[string]interface{}{}, "base64://e3sgJGwgOj0gY2F0ICJsYW5nPSIgLmxhbmcgfX0Ke3sgbm9zcGFjZSAkbCB9fQ==") assert.NoError(t, err) - tp2, err := template.LoadTextTemplate(ctx, reg, nil, "", "", map[string]interface{}{}, "base64://c3R1YiBlbWFpbCBib2R5IHt7IC5Cb2R5IH19") + tp2, err := template.LoadText(ctx, reg, nil, "", "", map[string]interface{}{}, "base64://c3R1YiBlbWFpbCBib2R5IHt7IC5Cb2R5IH19") assert.NoError(t, err) require.NotEqualf(t, tp1, tp2, "Expected remote template 1 and remote template 2 to not be equal") diff --git a/courier/template/recovery_invalid.go b/courier/template/recovery_invalid.go deleted file mode 100644 index a90995e3364e..000000000000 --- a/courier/template/recovery_invalid.go +++ /dev/null @@ -1,41 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type ( - RecoveryInvalid struct { - d TemplateDependencies - m *RecoveryInvalidModel - } - RecoveryInvalidModel struct { - To string - } -) - -func NewRecoveryInvalid(d TemplateDependencies, m *RecoveryInvalidModel) *RecoveryInvalid { - return &RecoveryInvalid{d: d, m: m} -} - -func (t *RecoveryInvalid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *RecoveryInvalid) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.subject.gotmpl", "recovery/invalid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Subject) -} - -func (t *RecoveryInvalid) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.body.gotmpl", "recovery/invalid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Body.HTML) -} - -func (t *RecoveryInvalid) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/invalid/email.body.plaintext.gotmpl", "recovery/invalid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryInvalid().Body.PlainText) -} - -func (t *RecoveryInvalid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/recovery_valid.go b/courier/template/recovery_valid.go deleted file mode 100644 index ba7d1c0fe187..000000000000 --- a/courier/template/recovery_valid.go +++ /dev/null @@ -1,43 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type ( - RecoveryValid struct { - d TemplateDependencies - m *RecoveryValidModel - } - RecoveryValidModel struct { - To string - RecoveryURL string - Identity map[string]interface{} - } -) - -func NewRecoveryValid(d TemplateDependencies, m *RecoveryValidModel) *RecoveryValid { - return &RecoveryValid{d: d, m: m} -} - -func (t *RecoveryValid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *RecoveryValid) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.subject.gotmpl", "recovery/valid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Subject) -} - -func (t *RecoveryValid) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.body.gotmpl", "recovery/valid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Body.HTML) -} - -func (t *RecoveryValid) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "recovery/valid/email.body.plaintext.gotmpl", "recovery/valid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesRecoveryValid().Body.PlainText) -} - -func (t *RecoveryValid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/sms/otp.go b/courier/template/sms/otp.go new file mode 100644 index 000000000000..ef003f63b0ea --- /dev/null +++ b/courier/template/sms/otp.go @@ -0,0 +1,38 @@ +package sms + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + OTPMessage struct { + d template.Dependencies + m *OTPMessageModel + } + + OTPMessageModel struct { + To string + Code string + Identity map[string]interface{} + } +) + +func NewOTPMessage(d template.Dependencies, m *OTPMessageModel) *OTPMessage { + return &OTPMessage{d: d, m: m} +} + +func (t *OTPMessage) PhoneNumber() (string, error) { + return t.m.To, nil +} + +func (t *OTPMessage) SMSBody(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "otp/sms.body.gotmpl", "otp/sms.body*", t.m, "") +} + +func (t *OTPMessage) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/sms/otp_test.go b/courier/template/sms/otp_test.go new file mode 100644 index 000000000000..01f4dcbbacb9 --- /dev/null +++ b/courier/template/sms/otp_test.go @@ -0,0 +1,34 @@ +package sms_test + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestNewOTPMessage(t *testing.T) { + _, reg := internal.NewFastRegistryWithMocks(t) + + const ( + expectedPhone = "+12345678901" + otp = "012345" + ) + + tpl := sms.NewOTPMessage(reg, &sms.OTPMessageModel{To: expectedPhone, Code: otp}) + + expectedBody := fmt.Sprintf("Your verification code is: %s\n", otp) + + actualBody, err := tpl.SMSBody(context.Background()) + require.NoError(t, err) + assert.Equal(t, expectedBody, actualBody) + + actualPhone, err := tpl.PhoneNumber() + require.NoError(t, err) + assert.Equal(t, expectedPhone, actualPhone) +} diff --git a/courier/template/sms/stub.go b/courier/template/sms/stub.go new file mode 100644 index 000000000000..fa2fb19e3b5f --- /dev/null +++ b/courier/template/sms/stub.go @@ -0,0 +1,38 @@ +package sms + +import ( + "context" + "encoding/json" + "os" + + "github.com/ory/kratos/courier/template" +) + +type ( + TestStub struct { + d template.Dependencies + m *TestStubModel + } + + TestStubModel struct { + To string + Body string + Identity map[string]interface{} + } +) + +func NewTestStub(d template.Dependencies, m *TestStubModel) *TestStub { + return &TestStub{d: d, m: m} +} + +func (t *TestStub) PhoneNumber() (string, error) { + return t.m.To, nil +} + +func (t *TestStub) SMSBody(ctx context.Context) (string, error) { + return template.LoadText(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "otp/test_stub/sms.body.gotmpl", "otp/test_stub/sms.body*", t.m, "") +} + +func (t *TestStub) MarshalJSON() ([]byte, error) { + return json.Marshal(t.m) +} diff --git a/courier/template/sms/stub_test.go b/courier/template/sms/stub_test.go new file mode 100644 index 000000000000..9b170a5532e4 --- /dev/null +++ b/courier/template/sms/stub_test.go @@ -0,0 +1,31 @@ +package sms_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/courier/template/sms" + "github.com/ory/kratos/internal" +) + +func TestNewTestStub(t *testing.T) { + _, reg := internal.NewFastRegistryWithMocks(t) + + const ( + expectedPhone = "+12345678901" + expectedBody = "test sms" + ) + + tpl := sms.NewTestStub(reg, &sms.TestStubModel{To: expectedPhone, Body: expectedBody}) + + actualBody, err := tpl.SMSBody(context.Background()) + require.NoError(t, err) + assert.Equal(t, "stub sms body test sms\n", actualBody) + + actualPhone, err := tpl.PhoneNumber() + require.NoError(t, err) + assert.Equal(t, expectedPhone, actualPhone) +} diff --git a/courier/template/stub.go b/courier/template/stub.go deleted file mode 100644 index 58f95a376696..000000000000 --- a/courier/template/stub.go +++ /dev/null @@ -1,42 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type TestStub struct { - d TemplateDependencies - m *TestStubModel -} - -type TestStubModel struct { - To string - Subject string - Body string -} - -func NewTestStub(d TemplateDependencies, m *TestStubModel) *TestStub { - return &TestStub{d: d, m: m} -} - -func (t *TestStub) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *TestStub) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.subject.gotmpl", "test_stub/email.subject*", t.m, "") -} - -func (t *TestStub) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.body.gotmpl", "test_stub/email.body*", t.m, "") -} - -func (t *TestStub) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "test_stub/email.body.plaintext.gotmpl", "test_stub/email.body.plaintext*", t.m, "") -} - -func (t *TestStub) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/template.go b/courier/template/template.go index 46e94b8e9c60..f81e8ce444aa 100644 --- a/courier/template/template.go +++ b/courier/template/template.go @@ -10,14 +10,15 @@ import ( ) type ( - TemplateConfig interface { + Config interface { CourierTemplatesRoot() string CourierTemplatesVerificationInvalid() *config.CourierEmailTemplate CourierTemplatesVerificationValid() *config.CourierEmailTemplate CourierTemplatesRecoveryInvalid() *config.CourierEmailTemplate CourierTemplatesRecoveryValid() *config.CourierEmailTemplate } - TemplateDependencies interface { + + Dependencies interface { CourierConfig(ctx context.Context) config.CourierConfigs HTTPClient(ctx context.Context, opts ...httpx.ResilientOptions) *retryablehttp.Client } diff --git a/courier/template/testhelpers/testhelpers.go b/courier/template/testhelpers/testhelpers.go index 0e2a3a490827..895ec767f408 100644 --- a/courier/template/testhelpers/testhelpers.go +++ b/courier/template/testhelpers/testhelpers.go @@ -9,6 +9,8 @@ import ( "path" "testing" + "github.com/ory/kratos/courier/template/email" + "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -55,21 +57,21 @@ func TestRemoteTemplates(t *testing.T, basePath string, tmplType courier.Templat return base64.StdEncoding.EncodeToString(f) } - getTemplate := func(tmpl courier.TemplateType, d template.TemplateDependencies) interface { + getTemplate := func(tmpl courier.TemplateType, d template.Dependencies) interface { EmailBody(context.Context) (string, error) EmailSubject(context.Context) (string, error) } { switch tmpl { case courier.TypeRecoveryInvalid: - return template.NewRecoveryInvalid(d, &template.RecoveryInvalidModel{}) + return email.NewRecoveryInvalid(d, &email.RecoveryInvalidModel{}) case courier.TypeRecoveryValid: - return template.NewRecoveryValid(d, &template.RecoveryValidModel{}) + return email.NewRecoveryValid(d, &email.RecoveryValidModel{}) case courier.TypeTestStub: - return template.NewTestStub(d, &template.TestStubModel{}) + return email.NewTestStub(d, &email.TestStubModel{}) case courier.TypeVerificationInvalid: - return template.NewVerificationInvalid(d, &template.VerificationInvalidModel{}) + return email.NewVerificationInvalid(d, &email.VerificationInvalidModel{}) case courier.TypeVerificationValid: - return template.NewVerificationValid(d, &template.VerificationValidModel{}) + return email.NewVerificationValid(d, &email.VerificationValidModel{}) default: return nil } diff --git a/courier/template/verification_invalid.go b/courier/template/verification_invalid.go deleted file mode 100644 index e78ec3a106f1..000000000000 --- a/courier/template/verification_invalid.go +++ /dev/null @@ -1,41 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type ( - VerificationInvalid struct { - d TemplateDependencies - m *VerificationInvalidModel - } - VerificationInvalidModel struct { - To string - } -) - -func NewVerificationInvalid(d TemplateDependencies, m *VerificationInvalidModel) *VerificationInvalid { - return &VerificationInvalid{d: d, m: m} -} - -func (t *VerificationInvalid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *VerificationInvalid) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.subject.gotmpl", "verification/invalid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Subject) -} - -func (t *VerificationInvalid) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.body.gotmpl", "verification/invalid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Body.HTML) -} - -func (t *VerificationInvalid) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/invalid/email.body.plaintext.gotmpl", "verification/invalid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationInvalid().Body.PlainText) -} - -func (t *VerificationInvalid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/courier/template/verification_valid.go b/courier/template/verification_valid.go deleted file mode 100644 index cdd6e25c6b85..000000000000 --- a/courier/template/verification_valid.go +++ /dev/null @@ -1,43 +0,0 @@ -package template - -import ( - "context" - "encoding/json" - "os" -) - -type ( - VerificationValid struct { - d TemplateDependencies - m *VerificationValidModel - } - VerificationValidModel struct { - To string - VerificationURL string - Identity map[string]interface{} - } -) - -func NewVerificationValid(d TemplateDependencies, m *VerificationValidModel) *VerificationValid { - return &VerificationValid{d: d, m: m} -} - -func (t *VerificationValid) EmailRecipient() (string, error) { - return t.m.To, nil -} - -func (t *VerificationValid) EmailSubject(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.subject.gotmpl", "verification/valid/email.subject*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Subject) -} - -func (t *VerificationValid) EmailBody(ctx context.Context) (string, error) { - return LoadHTMLTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.body.gotmpl", "verification/valid/email.body*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Body.HTML) -} - -func (t *VerificationValid) EmailBodyPlaintext(ctx context.Context) (string, error) { - return LoadTextTemplate(ctx, t.d, os.DirFS(t.d.CourierConfig(ctx).CourierTemplatesRoot()), "verification/valid/email.body.plaintext.gotmpl", "verification/valid/email.body.plaintext*", t.m, t.d.CourierConfig(ctx).CourierTemplatesVerificationValid().Body.PlainText) -} - -func (t *VerificationValid) MarshalJSON() ([]byte, error) { - return json.Marshal(t.m) -} diff --git a/driver/config/.snapshots/TestCourierSMS-case=configs_set.json b/driver/config/.snapshots/TestCourierSMS-case=configs_set.json new file mode 100644 index 000000000000..471e0019a448 --- /dev/null +++ b/driver/config/.snapshots/TestCourierSMS-case=configs_set.json @@ -0,0 +1,15 @@ +{ + "auth": { + "config": { + "password": "YourPass", + "user": "YourUsername" + }, + "type": "basic_auth" + }, + "body": "base64://e30=", + "header": { + "Content-Type": "application/x-www-form-urlencoded" + }, + "method": "POST", + "url": "https://api.twilio.com/2010-04-01/Accounts/YourAccountID/Messages.json" +} diff --git a/driver/config/.snapshots/TestCourierSMS-case=defaults.json b/driver/config/.snapshots/TestCourierSMS-case=defaults.json new file mode 100644 index 000000000000..19765bd501b6 --- /dev/null +++ b/driver/config/.snapshots/TestCourierSMS-case=defaults.json @@ -0,0 +1 @@ +null diff --git a/driver/config/config.go b/driver/config/config.go index 1f728668b328..588225a632a5 100644 --- a/driver/config/config.go +++ b/driver/config/config.go @@ -69,6 +69,9 @@ const ( ViperKeyCourierSMTPFrom = "courier.smtp.from_address" ViperKeyCourierSMTPFromName = "courier.smtp.from_name" ViperKeyCourierSMTPHeaders = "courier.smtp.headers" + ViperKeyCourierSMSRequestConfig = "courier.sms.request_config" + ViperKeyCourierSMSEnabled = "courier.sms.enabled" + ViperKeyCourierSMSFrom = "courier.sms.from" ViperKeySecretsDefault = "secrets.default" ViperKeySecretsCookie = "secrets.cookie" ViperKeySecretsCipher = "secrets.cipher" @@ -235,6 +238,9 @@ type ( CourierSMTPFrom() string CourierSMTPFromName() string CourierSMTPHeaders() map[string]string + CourierSMSEnabled() bool + CourierSMSFrom() string + CourierSMSRequestConfig() json.RawMessage CourierTemplatesRoot() string CourierTemplatesVerificationInvalid() *CourierEmailTemplate CourierTemplatesVerificationValid() *CourierEmailTemplate @@ -919,6 +925,33 @@ func (p *Config) CourierSMTPHeaders() map[string]string { return p.p.StringMap(ViperKeyCourierSMTPHeaders) } +func (p *Config) CourierSMSRequestConfig() json.RawMessage { + if !p.p.Bool(ViperKeyCourierSMSEnabled) { + return nil + } + + out, err := p.p.Marshal(kjson.Parser()) + if err != nil { + p.l.WithError(err).Warn("Unable to marshal self service strategy configuration.") + return nil + } + + config := gjson.GetBytes(out, ViperKeyCourierSMSRequestConfig).Raw + if len(config) <= 0 { + return json.RawMessage("{}") + } + + return json.RawMessage(config) +} + +func (p *Config) CourierSMSFrom() string { + return p.p.StringF(ViperKeyCourierSMSFrom, "Ory Kratos") +} + +func (p *Config) CourierSMSEnabled() bool { + return p.p.Bool(ViperKeyCourierSMSEnabled) +} + func splitUrlAndFragment(s string) (string, string) { i := strings.IndexByte(s, '#') if i < 0 { diff --git a/driver/config/config_test.go b/driver/config/config_test.go index fb674635fa19..808517557294 100644 --- a/driver/config/config_test.go +++ b/driver/config/config_test.go @@ -16,6 +16,8 @@ import ( "testing" "time" + "github.com/ory/x/snapshotx" + "github.com/ghodss/yaml" "github.com/spf13/cobra" @@ -1046,6 +1048,26 @@ func TestChangeMinPasswordLength(t *testing.T) { }) } +func TestCourierSMS(t *testing.T) { + ctx := context.Background() + + t.Run("case=configs set", func(t *testing.T) { + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, + configx.WithConfigFiles("stub/.kratos.courier.sms.yaml"), configx.SkipValidation()) + assert.True(t, conf.CourierSMSEnabled()) + snapshotx.SnapshotTExcept(t, conf.CourierSMSRequestConfig(), nil) + assert.Equal(t, "+49123456789", conf.CourierSMSFrom()) + }) + + t.Run("case=defaults", func(t *testing.T) { + conf, _ := config.New(ctx, logrusx.New("", ""), os.Stderr, configx.SkipValidation()) + + assert.False(t, conf.CourierSMSEnabled()) + snapshotx.SnapshotTExcept(t, conf.CourierSMSRequestConfig(), nil) + assert.Equal(t, "Ory Kratos", conf.CourierSMSFrom()) + }) +} + func TestCourierTemplatesConfig(t *testing.T) { ctx := context.Background() diff --git a/driver/config/stub/.kratos.courier.sms.yaml b/driver/config/stub/.kratos.courier.sms.yaml new file mode 100644 index 000000000000..1c2fbae89c3a --- /dev/null +++ b/driver/config/stub/.kratos.courier.sms.yaml @@ -0,0 +1,28 @@ +dsn: sqlite://foo.db?mode=memory&_fk=true + +selfservice: + default_browser_return_url: http://return-to-3-test.ory.sh/ + +identity: + default_schema_id: default + schemas: + - id: default + url: base64://ewogICIkaWQiOiAib3J5Oi8vaWRlbnRpdHktdGVzdC1zY2hlbWEiLAogICIkc2NoZW1hIjogImh0dHA6Ly9qc29uLXNjaGVtYS5vcmcvZHJhZnQtMDcvc2NoZW1hIyIsCiAgInRpdGxlIjogIklkZW50aXR5U2NoZW1hIiwKICAidHlwZSI6ICJvYmplY3QiLAogICJwcm9wZXJ0aWVzIjogewogICAgInRyYWl0cyI6IHsKICAgICAgInR5cGUiOiAib2JqZWN0IiwKICAgICAgInByb3BlcnRpZXMiOiB7CiAgICAgICAgIm5hbWUiOiB7CiAgICAgICAgICAidHlwZSI6ICJvYmplY3QiLAogICAgICAgICAgInByb3BlcnRpZXMiOiB7CiAgICAgICAgICAgICJmaXJzdCI6IHsKICAgICAgICAgICAgICAidHlwZSI6ICJzdHJpbmciCiAgICAgICAgICAgIH0sCiAgICAgICAgICAgICJsYXN0IjogewogICAgICAgICAgICAgICJ0eXBlIjogInN0cmluZyIKICAgICAgICAgICAgfQogICAgICAgICAgfQogICAgICAgIH0KICAgICAgfSwKICAgICAgInJlcXVpcmVkIjogWwogICAgICAgICJuYW1lIgogICAgICBdLAogICAgICAiYWRkaXRpb25hbFByb3BlcnRpZXMiOiB0cnVlCiAgICB9CiAgfQp9 + +courier: + smtp: + connection_uri: smtp://foo:bar@baz/ + sms: + enabled: true + from: '+49123456789' + request_config: + url: https://api.twilio.com/2010-04-01/Accounts/YourAccountID/Messages.json + method: POST + body: base64://e30= + header: + 'Content-Type': 'application/x-www-form-urlencoded' + auth: + type: basic_auth + config: + user: YourUsername + password: YourPass diff --git a/driver/registry_default.go b/driver/registry_default.go index 0d1e6fa1d05d..62a089edb233 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -588,8 +588,8 @@ func (m *RegistryDefault) SetPersister(p persistence.Persister) { m.persister = p } -func (m *RegistryDefault) Courier(ctx context.Context) *courier.Courier { - return courier.NewSMTP(ctx, m) +func (m *RegistryDefault) Courier(ctx context.Context) courier.Courier { + return courier.NewCourier(ctx, m) } func (m *RegistryDefault) ContinuityManager() continuity.Manager { diff --git a/embedx/config.schema.json b/embedx/config.schema.json index d0310b7ca866..0f3674dd358a 100644 --- a/embedx/config.schema.json +++ b/embedx/config.schema.json @@ -1445,6 +1445,81 @@ "connection_uri" ], "additionalProperties": false + }, + "sms": { + "title": "SMS sender configuration", + "description": "Configures outgoing sms messages using HTTP protocol with generic SMS provider", + "type": "object", + "properties": { + "enabled": { + "description": "Determines if SMS functionality is enabled", + "type": "boolean", + "default": false + }, + "from": { + "title": "SMS Sender Address", + "description": "The recipient of a sms will see this as the sender address.", + "type": "string", + "default": "Ory Kratos" + }, + "request_config": { + "type": "object", + "properties": { + "url": { + "title": "HTTP address of API endpoint", + "description": "This URL will be used to connect to the SMS provider.", + "examples": [ + "https://api.twillio.com/sms/send" + ], + "type": "string", + "pattern": "^https?:\\/\\/.*" + }, + "method": { + "type": "string", + "description": "The HTTP method to use (GET, POST, etc)." + }, + "header": { + "type": "object", + "description": "The HTTP headers that must be applied to request", + "additionalProperties": { + "type": "string" + } + }, + "body": { + "type": "string", + "format": "uri", + "pattern": "^(http|https|file|base64)://", + "description": "URI pointing to the jsonnet template used for payload generation. Only used for those HTTP methods, which support HTTP body payloads", + "examples": [ + "file:///path/to/body.jsonnet", + "file://./body.jsonnet", + "base64://ZnVuY3Rpb24oY3R4KSB7CiAgaWRlbnRpdHlfaWQ6IGlmIGN0eFsiaWRlbnRpdHkiXSAhPSBudWxsIHRoZW4gY3R4LmlkZW50aXR5LmlkLAp9=", + "https://oryapis.com/default_body.jsonnet" + ] + }, + "auth": { + "type": "object", + "title": "Auth mechanisms", + "description": "Define which auth mechanism to use for auth with the SMS provider", + "oneOf": [ + { + "$ref": "#/definitions/webHookAuthApiKeyProperties" + }, + { + "$ref": "#/definitions/webHookAuthBasicAuthProperties" + } + ] + }, + "additionalProperties": false + }, + "required": [ + "url", + "method" + ], + "additionalProperties": false + } + }, + "additionalProperties": false } }, "required": [ diff --git a/go.mod b/go.mod index 4a3980dbad82..b2e99b81671e 100644 --- a/go.mod +++ b/go.mod @@ -76,7 +76,7 @@ require ( github.com/ory/kratos-client-go v0.6.3-alpha.1 github.com/ory/mail/v3 v3.0.0 github.com/ory/nosurf v1.2.7 - github.com/ory/x v0.0.345 + github.com/ory/x v0.0.348 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.3.0 diff --git a/go.sum b/go.sum index 66299a3bbd6a..db4e460ee60e 100644 --- a/go.sum +++ b/go.sum @@ -1804,8 +1804,8 @@ github.com/ory/x v0.0.205/go.mod h1:A1s4iwmFIppRXZLF3J9GGWeY/HpREVm0Dk5z/787iek= github.com/ory/x v0.0.250/go.mod h1:jUJaVptu+geeqlb9SyQCogTKj5ztSDIF6APkhbKtwLc= github.com/ory/x v0.0.272/go.mod h1:1TTPgJGQutrhI2OnwdrTIHE9ITSf4MpzXFzA/ncTGRc= github.com/ory/x v0.0.288/go.mod h1:APpShLyJcVzKw1kTgrHI+j/L9YM+8BRjHlcYObc7C1U= -github.com/ory/x v0.0.345 h1:e3ZCt8SxLXQdn/fWM/xjxl+2+DhjrTNIY9DVwYMR2m4= -github.com/ory/x v0.0.345/go.mod h1:Ddbu3ecSaNDgxdntdD1gDu3ALG5fWR5AwUB1ILeBUNE= +github.com/ory/x v0.0.348 h1:Z2wbEvSpTindtjKTTrd3grIlWbBtvW2udYG5ZjTZHTo= +github.com/ory/x v0.0.348/go.mod h1:Ddbu3ecSaNDgxdntdD1gDu3ALG5fWR5AwUB1ILeBUNE= github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= diff --git a/request/auth.go b/request/auth.go new file mode 100644 index 000000000000..398aa0aef910 --- /dev/null +++ b/request/auth.go @@ -0,0 +1,31 @@ +package request + +import ( + "encoding/json" + "fmt" + + "github.com/hashicorp/go-retryablehttp" +) + +type ( + AuthStrategy interface { + apply(req *retryablehttp.Request) + } + + authStrategyFactory func(c json.RawMessage) (AuthStrategy, error) +) + +var strategyFactories = map[string]authStrategyFactory{ + "": newNoopAuthStrategy, + "api_key": newApiKeyStrategy, + "basic_auth": newBasicAuthStrategy, +} + +func authStrategy(name string, config json.RawMessage) (AuthStrategy, error) { + strategyFactory, ok := strategyFactories[name] + if ok { + return strategyFactory(config) + } + + return nil, fmt.Errorf("unsupported auth type: %s", name) +} diff --git a/request/auth_strategy.go b/request/auth_strategy.go new file mode 100644 index 000000000000..f280a5b92e70 --- /dev/null +++ b/request/auth_strategy.go @@ -0,0 +1,78 @@ +package request + +import ( + "encoding/json" + "net/http" + + "github.com/hashicorp/go-retryablehttp" +) + +type ( + noopAuthStrategy struct{} + + basicAuthStrategy struct { + user string + password string + } + + apiKeyStrategy struct { + name string + value string + in string + } +) + +func newNoopAuthStrategy(_ json.RawMessage) (AuthStrategy, error) { + return &noopAuthStrategy{}, nil +} + +func (c *noopAuthStrategy) apply(_ *retryablehttp.Request) {} + +func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) { + type config struct { + User string + Password string + } + + var c config + if err := json.Unmarshal(raw, &c); err != nil { + return nil, err + } + + return &basicAuthStrategy{ + user: c.User, + password: c.Password, + }, nil +} + +func (c *basicAuthStrategy) apply(req *retryablehttp.Request) { + req.SetBasicAuth(c.user, c.password) +} + +func newApiKeyStrategy(raw json.RawMessage) (AuthStrategy, error) { + type config struct { + In string + Name string + Value string + } + + var c config + if err := json.Unmarshal(raw, &c); err != nil { + return nil, err + } + + return &apiKeyStrategy{ + in: c.In, + name: c.Name, + value: c.Value, + }, nil +} + +func (c *apiKeyStrategy) apply(req *retryablehttp.Request) { + switch c.in { + case "cookie": + req.AddCookie(&http.Cookie{Name: c.name, Value: c.value}) + default: + req.Header.Set(c.name, c.value) + } +} diff --git a/request/auth_strategy_test.go b/request/auth_strategy_test.go new file mode 100644 index 000000000000..e2422fb40425 --- /dev/null +++ b/request/auth_strategy_test.go @@ -0,0 +1,69 @@ +package request + +import ( + "net/http" + "testing" + + "github.com/hashicorp/go-retryablehttp" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNoopAuthStrategy(t *testing.T) { + req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} + auth := noopAuthStrategy{} + + auth.apply(&req) + + assert.Empty(t, req.Header, "Empty auth strategy shall not modify any request headers") +} + +func TestBasicAuthStrategy(t *testing.T) { + req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} + auth := basicAuthStrategy{ + user: "test-user", + password: "test-pass", + } + + auth.apply(&req) + + assert.Len(t, req.Header, 1) + + user, pass, _ := req.BasicAuth() + assert.Equal(t, "test-user", user) + assert.Equal(t, "test-pass", pass) +} + +func TestApiKeyInHeaderStrategy(t *testing.T) { + req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} + auth := apiKeyStrategy{ + in: "header", + name: "my-api-key-name", + value: "my-api-key-value", + } + + auth.apply(&req) + + require.Len(t, req.Header, 1) + + actualValue := req.Header.Get("my-api-key-name") + assert.Equal(t, "my-api-key-value", actualValue) +} + +func TestApiKeyInCookieStrategy(t *testing.T) { + req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} + auth := apiKeyStrategy{ + in: "cookie", + name: "my-api-key-name", + value: "my-api-key-value", + } + + auth.apply(&req) + + cookies := req.Cookies() + assert.Len(t, cookies, 1) + + assert.Equal(t, "my-api-key-name", cookies[0].Name) + assert.Equal(t, "my-api-key-value", cookies[0].Value) +} diff --git a/request/auth_test.go b/request/auth_test.go new file mode 100644 index 000000000000..c0df79336905 --- /dev/null +++ b/request/auth_test.go @@ -0,0 +1,56 @@ +package request + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthStrategy(t *testing.T) { + for _, tc := range map[string]struct { + name string + config string + expected AuthStrategy + }{ + "noop": { + name: "", + config: "", + expected: &noopAuthStrategy{}, + }, + "basic_auth": { + name: "basic_auth", + config: `{ + "user": "test-api-user", + "password": "secret" + }`, + expected: &basicAuthStrategy{}, + }, + "api-key/header": { + name: "api_key", + config: `{ + "in": "header", + "name": "my-api-key", + "value": "secret" + }`, + expected: &apiKeyStrategy{}, + }, + "api-key/cookie": { + name: "api_key", + config: `{ + "in": "cookie", + "name": "my-api-key", + "value": "secret" + }`, + expected: &apiKeyStrategy{}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + strategy, err := authStrategy(tc.name, json.RawMessage(tc.config)) + require.NoError(t, err) + + assert.IsTypef(t, tc.expected, strategy, "auth strategy should be of the expected type") + }) + } +} diff --git a/request/builder.go b/request/builder.go new file mode 100644 index 000000000000..fbd1782f2936 --- /dev/null +++ b/request/builder.go @@ -0,0 +1,200 @@ +package request + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "net/url" + "reflect" + "strings" + + "github.com/google/go-jsonnet" + "github.com/hashicorp/go-retryablehttp" + + "github.com/ory/x/fetcher" + "github.com/ory/x/logrusx" +) + +const ( + ContentTypeForm = "application/x-www-form-urlencoded" + ContentTypeJSON = "application/json" +) + +type Builder struct { + r *retryablehttp.Request + log *logrusx.Logger + conf *Config + fetchClient *retryablehttp.Client +} + +func NewBuilder(config json.RawMessage, client *retryablehttp.Client, l *logrusx.Logger) (*Builder, error) { + c, err := parseConfig(config) + if err != nil { + return nil, err + } + + r, err := retryablehttp.NewRequest(c.Method, c.URL, nil) + if err != nil { + return nil, err + } + + return &Builder{ + r: r, + log: l, + conf: c, + fetchClient: client, + }, nil +} + +func (b *Builder) addAuth() error { + authConfig := b.conf.Auth + + strategy, err := authStrategy(authConfig.Type, authConfig.Config) + if err != nil { + return err + } + + strategy.apply(b.r) + + return nil +} + +func (b *Builder) addBody(body interface{}) error { + if isNilInterface(body) { + return nil + } + + contentType := b.r.Header.Get("Content-Type") + + if b.conf.TemplateURI == "" { + return errors.New("got empty template path for request with body") + } + + tpl, err := b.readTemplate() + if err != nil { + return err + } + + switch contentType { + case ContentTypeForm: + if err := b.addURLEncodedBody(tpl, body); err != nil { + return err + } + case ContentTypeJSON: + if err := b.addJSONBody(tpl, body); err != nil { + return err + } + default: + return errors.New("invalid config - incorrect Content-Type for request with body") + } + + return nil +} + +func (b *Builder) addJSONBody(template *bytes.Buffer, body interface{}) error { + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + enc.SetIndent("", "") + + if err := enc.Encode(body); err != nil { + return err + } + + vm := jsonnet.MakeVM() + vm.TLACode("ctx", buf.String()) + + res, err := vm.EvaluateAnonymousSnippet(b.conf.TemplateURI, template.String()) + if err != nil { + return err + } + + rb := strings.NewReader(res) + b.r.Body = io.NopCloser(rb) + b.r.ContentLength = int64(rb.Len()) + + return nil +} + +func (b *Builder) addURLEncodedBody(template *bytes.Buffer, body interface{}) error { + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + enc.SetIndent("", "") + + if err := enc.Encode(body); err != nil { + return err + } + + vm := jsonnet.MakeVM() + vm.TLACode("ctx", buf.String()) + + res, err := vm.EvaluateAnonymousSnippet(b.conf.TemplateURI, template.String()) + if err != nil { + return err + } + + values := map[string]string{} + if err := json.Unmarshal([]byte(res), &values); err != nil { + return err + } + + u := url.Values{} + + for key, value := range values { + u.Add(key, value) + } + + rb := strings.NewReader(u.Encode()) + b.r.Body = io.NopCloser(rb) + + return nil +} + +func (b *Builder) BuildRequest(body interface{}) (*retryablehttp.Request, error) { + b.r.Header = b.conf.Header + if err := b.addAuth(); err != nil { + return nil, err + } + + // According to the HTTP spec any request method, but TRACE is allowed to + // have a body. Even this is a bad practice for some of them, like for GET + if b.conf.Method != http.MethodTrace { + if err := b.addBody(body); err != nil { + return nil, err + } + } + + return b.r, nil +} + +func (b *Builder) readTemplate() (*bytes.Buffer, error) { + templateURI := b.conf.TemplateURI + + if templateURI == "" { + return nil, nil + } + + f := fetcher.NewFetcher(fetcher.WithClient(b.fetchClient)) + + tpl, err := f.Fetch(templateURI) + if errors.Is(err, fetcher.ErrUnknownScheme) { + // legacy filepath + templateURI = "file://" + templateURI + b.log.WithError(err).Warnf("support for filepaths without a 'file://' scheme will be dropped in the next release, please use %s instead in your config", templateURI) + + tpl, err = f.Fetch(templateURI) + } + // this handles the first error if it is a known scheme error, or the second fetch error + if err != nil { + return nil, err + } + + return tpl, nil +} + +func isNilInterface(i interface{}) bool { + return i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) +} diff --git a/request/builder_test.go b/request/builder_test.go new file mode 100644 index 000000000000..ad868392bf2d --- /dev/null +++ b/request/builder_test.go @@ -0,0 +1,272 @@ +package request + +import ( + _ "embed" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/x/logrusx" +) + +type testRequestBody struct { + To string + From string + Body string +} + +//go:embed stub/test_body.jsonnet +var testJSONNetTemplate []byte + +func TestBuildRequest(t *testing.T) { + for _, tc := range []struct { + name string + method string + url string + authStrategy string + expectedHeader http.Header + bodyTemplateURI string + body *testRequestBody + expectedBody string + rawConfig string + }{ + { + name: "POST request without auth", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint1", + authStrategy: "", // noop strategy + bodyTemplateURI: "file://./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+15056445993", + From: "+12288534869", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+12288534869\",\n \"To\": \"+15056445993\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint1", + "method": "POST", + "body": "file://./stub/test_body.jsonnet" + }`, + }, + { + name: "POST request with legacy template path", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint1", + bodyTemplateURI: "./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+15056445993", + From: "+12288534869", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+12288534869\",\n \"To\": \"+15056445993\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint1", + "method": "POST", + "body": "./stub/test_body.jsonnet" + }`, + }, + { + name: "POST request with base64 encoded template path", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint1", + bodyTemplateURI: "base64://" + base64.StdEncoding.EncodeToString(testJSONNetTemplate), + body: &testRequestBody{ + To: "+15056445993", + From: "+12288534869", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+12288534869\",\n \"To\": \"+15056445993\"\n}\n", + rawConfig: fmt.Sprintf(`{ + "url": "https://test.kratos.ory.sh/my_endpoint1", + "method": "POST", + "body": "base64://%s" + }`, base64.StdEncoding.EncodeToString(testJSONNetTemplate)), + }, + { + name: "POST request with custom header", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint2", + authStrategy: "", + expectedHeader: map[string][]string{"Custom-Header": {"test"}}, + bodyTemplateURI: "file://./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+12127110378", + From: "+15822228108", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+15822228108\",\n \"To\": \"+12127110378\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint2", + "method": "POST", + "header": { + "Custom-Header": "test" + }, + "body": "file://./stub/test_body.jsonnet" + }`, + }, + { + name: "GET request with body", + method: "GET", + url: "https://test.kratos.ory.sh/my_endpoint3", + authStrategy: "basic_auth", + bodyTemplateURI: "file://./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+14134242223", + From: "+13104661805", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+13104661805\",\n \"To\": \"+14134242223\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint3", + "method": "GET", + "auth": { + "type": "basic_auth", + "config": { + "user": "test-api-user", + "password": "secret" + } + }, + "body": "file://./stub/test_body.jsonnet" + }`, + }, + { + name: "GET request without body", + method: "GET", + url: "https://test.kratos.ory.sh/my_endpoint4", + authStrategy: "basic_auth", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint4", + "method": "GET", + "auth": { + "type": "basic_auth", + "config": { + "user": "test-api-user", + "password": "secret" + } + } + }`, + }, + { + name: "DELETE request with body", + method: "DELETE", + url: "https://test.kratos.ory.sh/my_endpoint5", + authStrategy: "api_key", + bodyTemplateURI: "file://./stub/test_body.jsonnet", + body: &testRequestBody{ + To: "+12235499085", + From: "+14253787846", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+14253787846\",\n \"To\": \"+12235499085\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint5", + "method": "DELETE", + "body": "file://./stub/test_body.jsonnet", + "auth": { + "type": "api_key", + "config": { + "in": "header", + "name": "my-api-key", + "value": "secret" + } + } + }`, + }, + { + name: "POST request with urlencoded body", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint6", + bodyTemplateURI: "file://./stub/test_body.jsonnet", + authStrategy: "api_key", + expectedHeader: map[string][]string{"Content-Type": {ContentTypeForm}}, + body: &testRequestBody{ + To: "+14134242223", + From: "+13104661805", + Body: "test-sms-body", + }, + expectedBody: "Body=test-sms-body&From=%2B13104661805&To=%2B14134242223", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint6", + "method": "POST", + "body": "file://./stub/test_body.jsonnet", + "header": { + "Content-Type": "application/x-www-form-urlencoded" + }, + "auth": { + "type": "api_key", + "config": { + "in": "cookie", + "name": "my-api-key", + "value": "secret" + } + } + }`, + }, + { + name: "POST request with default body type", + method: "POST", + url: "https://test.kratos.ory.sh/my_endpoint7", + bodyTemplateURI: "file://./stub/test_body.jsonnet", + authStrategy: "basic_auth", + expectedHeader: map[string][]string{"Content-Type": {ContentTypeJSON}}, + body: &testRequestBody{ + To: "+14134242223", + From: "+13104661805", + Body: "test-sms-body", + }, + expectedBody: "{\n \"Body\": \"test-sms-body\",\n \"From\": \"+13104661805\",\n \"To\": \"+14134242223\"\n}\n", + rawConfig: `{ + "url": "https://test.kratos.ory.sh/my_endpoint7", + "method": "POST", + "body": "file://./stub/test_body.jsonnet", + "auth": { + "type": "basic_auth", + "config": { + "user": "test-api-user", + "password": "secret" + } + } + }`, + }, + } { + t.Run("request-type="+tc.name, func(t *testing.T) { + l := logrusx.New("kratos", "test") + + rb, err := NewBuilder(json.RawMessage(tc.rawConfig), nil, l) + require.NoError(t, err) + + assert.Equal(t, tc.bodyTemplateURI, rb.conf.TemplateURI) + assert.Equal(t, tc.authStrategy, rb.conf.Auth.Type) + + req, err := rb.BuildRequest(tc.body) + require.NoError(t, err) + + assert.Equal(t, tc.url, req.URL.String()) + assert.Equal(t, tc.method, req.Method) + + if tc.body != nil { + requestBody, err := ioutil.ReadAll(req.Body) + require.NoError(t, err) + + assert.Equal(t, tc.expectedBody, string(requestBody)) + } + + if tc.expectedHeader != nil { + mustContainHeader(t, tc.expectedHeader, req.Header) + } + }) + } +} + +func mustContainHeader(t *testing.T, expected http.Header, actual http.Header) { + for k := range expected { + require.Contains(t, actual, k) + assert.Equal(t, expected[k], actual[k]) + } +} diff --git a/request/config.go b/request/config.go new file mode 100644 index 000000000000..caf5061bf326 --- /dev/null +++ b/request/config.go @@ -0,0 +1,61 @@ +package request + +import ( + "encoding/json" + "net/http" + + "github.com/tidwall/gjson" +) + +type ( + Auth struct { + Type string + Config json.RawMessage + } + + Config struct { + Method string `json:"method"` + URL string `json:"url"` + TemplateURI string `json:"body"` + Header http.Header `json:"header"` + Auth Auth `json:"auth,omitempty"` + } +) + +func parseConfig(r json.RawMessage) (*Config, error) { + type rawConfig struct { + Method string `json:"method"` + URL string `json:"url"` + TemplateURI string `json:"body"` + Header json.RawMessage `json:"header"` + Auth Auth `json:"auth,omitempty"` + } + + var rc rawConfig + err := json.Unmarshal(r, &rc) + if err != nil { + return nil, err + } + + rawHeader := gjson.ParseBytes(rc.Header).Map() + hdr := http.Header{} + + _, ok := rawHeader["Content-Type"] + if !ok { + hdr.Set("Content-Type", ContentTypeJSON) + } + + for key, value := range rawHeader { + hdr.Set(key, value.String()) + } + + c := Config{ + Method: rc.Method, + URL: rc.URL, + TemplateURI: rc.TemplateURI, + Header: hdr, + Auth: rc.Auth, + } + + return &c, nil +} diff --git a/request/stub/test_body.jsonnet b/request/stub/test_body.jsonnet new file mode 100644 index 000000000000..03edc83a65e2 --- /dev/null +++ b/request/stub/test_body.jsonnet @@ -0,0 +1,5 @@ +function(ctx) { + From: ctx.From, + To: ctx.To, + Body: ctx.Body, +} diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go index 03476a8ac726..bb427ee74653 100644 --- a/selfservice/hook/web_hook.go +++ b/selfservice/hook/web_hook.go @@ -1,22 +1,13 @@ package hook import ( - "bytes" "context" "encoding/json" "fmt" - "io" "net/http" - "github.com/hashicorp/go-retryablehttp" - - "github.com/ory/x/fetcher" - "github.com/ory/x/logrusx" - - "github.com/google/go-jsonnet" - "github.com/pkg/errors" - "github.com/ory/kratos/identity" + "github.com/ory/kratos/request" "github.com/ory/kratos/selfservice/flow" "github.com/ory/kratos/selfservice/flow/login" "github.com/ory/kratos/selfservice/flow/recovery" @@ -32,32 +23,6 @@ var _ verification.PostHookExecutor = new(WebHook) var _ recovery.PostHookExecutor = new(WebHook) type ( - AuthStrategy interface { - apply(req *retryablehttp.Request) - } - - authStrategyFactory func(c json.RawMessage) (AuthStrategy, error) - - NoopAuthStrategy struct{} - - BasicAuthStrategy struct { - user string - password string - } - - ApiKeyStrategy struct { - name string - value string - in string - } - - WebHookConfig struct { - Method string - URL string - TemplateURI string - Auth AuthStrategy - } - webHookDependencies interface { x.LoggingProvider x.HTTPClientProvider @@ -72,113 +37,13 @@ type ( } WebHook struct { - r webHookDependencies - c json.RawMessage + deps webHookDependencies + conf json.RawMessage } ) -var strategyFactories = map[string]authStrategyFactory{ - "": newNoopAuthStrategy, - "api_key": newApiKeyStrategy, - "basic_auth": newBasicAuthStrategy, -} - -func newAuthStrategy(name string, c json.RawMessage) (as AuthStrategy, err error) { - if f, ok := strategyFactories[name]; ok { - as, err = f(c) - } else { - err = fmt.Errorf("unsupported auth type: %s", name) - } - return -} - -func newNoopAuthStrategy(_ json.RawMessage) (AuthStrategy, error) { - return &NoopAuthStrategy{}, nil -} - -func (c *NoopAuthStrategy) apply(_ *retryablehttp.Request) {} - -func newBasicAuthStrategy(raw json.RawMessage) (AuthStrategy, error) { - type config struct { - User string - Password string - } - - var c config - if err := json.Unmarshal(raw, &c); err != nil { - return nil, err - } - - return &BasicAuthStrategy{ - user: c.User, - password: c.Password, - }, nil -} - -func (c *BasicAuthStrategy) apply(req *retryablehttp.Request) { - req.SetBasicAuth(c.user, c.password) -} - -func newApiKeyStrategy(raw json.RawMessage) (AuthStrategy, error) { - type config struct { - In string - Name string - Value string - } - - var c config - if err := json.Unmarshal(raw, &c); err != nil { - return nil, err - } - - return &ApiKeyStrategy{ - in: c.In, - name: c.Name, - value: c.Value, - }, nil -} - -func (c *ApiKeyStrategy) apply(req *retryablehttp.Request) { - switch c.in { - case "cookie": - req.AddCookie(&http.Cookie{Name: c.name, Value: c.value}) - default: - req.Header.Set(c.name, c.value) - } -} - -func newWebHookConfig(r json.RawMessage) (*WebHookConfig, error) { - type rawWebHookConfig struct { - Method string - Url string - Body string - Auth struct { - Type string - Config json.RawMessage - } - } - - var rc rawWebHookConfig - err := json.Unmarshal(r, &rc) - if err != nil { - return nil, err - } - - as, err := newAuthStrategy(rc.Auth.Type, rc.Auth.Config) - if err != nil { - return nil, fmt.Errorf("failed to create web hook auth strategy: %w", err) - } - - return &WebHookConfig{ - Method: rc.Method, - URL: rc.Url, - TemplateURI: rc.Body, - Auth: as, - }, nil -} - func NewWebHook(r webHookDependencies, c json.RawMessage) *WebHook { - return &WebHook{r: r, c: c} + return &WebHook{deps: r, conf: c} } func (e *WebHook) ExecuteLoginPreHook(_ http.ResponseWriter, req *http.Request, flow *login.Flow) error { @@ -250,88 +115,22 @@ func (e *WebHook) ExecuteSettingsPostPersistHook(_ http.ResponseWriter, req *htt } func (e *WebHook) execute(ctx context.Context, data *templateContext) error { - httpClient := e.r.HTTPClient(ctx) - - // TODO: reminder for the future: move parsing of config to the web hook initialization - conf, err := newWebHookConfig(e.c) + builder, err := request.NewBuilder(e.conf, e.deps.HTTPClient(ctx), e.deps.Logger()) if err != nil { - return fmt.Errorf("failed to parse web hook config: %w", err) - } - - var body io.Reader - if conf.Method != "TRACE" { - // According to the HTTP spec any request method, but TRACE is allowed to - // have a body. Even this is a really bad practice for some of them, like for - // GET - body, err = createBody(e.r.Logger(), conf.TemplateURI, data, httpClient) - if err != nil { - return fmt.Errorf("failed to create web hook body: %w", err) - } - } - - if body == nil { - body = bytes.NewReader(make([]byte, 0)) - } - - if err = doHttpCall(ctx, conf.Method, conf.URL, conf.Auth, body, httpClient); err != nil { - return fmt.Errorf("failed to call web hook %w", err) - } - return nil -} - -func createBody(l *logrusx.Logger, templateURI string, data *templateContext, hc *retryablehttp.Client) (*bytes.Reader, error) { - if len(templateURI) == 0 { - return bytes.NewReader(make([]byte, 0)), nil + return err } - f := fetcher.NewFetcher(fetcher.WithClient(hc)) - - template, err := f.Fetch(templateURI) - if errors.Is(err, fetcher.ErrUnknownScheme) { - // legacy filepath - templateURI = "file://" + templateURI - l.WithError(err).Warnf("support for filepaths without a 'file://' scheme will be dropped in the next release, please use %s instead in your config", templateURI) - template, err = f.Fetch(templateURI) - } - // this handles the first error if it is a known scheme error, or the second fetch error + req, err := builder.BuildRequest(data) if err != nil { - return nil, err - } - - vm := jsonnet.MakeVM() - - buf := new(bytes.Buffer) - enc := json.NewEncoder(buf) - enc.SetEscapeHTML(false) - enc.SetIndent("", "") - - if err := enc.Encode(data); err != nil { - return nil, err - } - vm.TLACode("ctx", buf.String()) - - if res, err := vm.EvaluateAnonymousSnippet(templateURI, template.String()); err != nil { - return nil, err - } else { - return bytes.NewReader([]byte(res)), nil + return err } -} -func doHttpCall(ctx context.Context, method string, url string, as AuthStrategy, body io.Reader, hc *retryablehttp.Client) error { - req, err := retryablehttp.NewRequest(method, url, body) - req = req.WithContext(ctx) + resp, err := e.deps.HTTPClient(ctx).Do(req) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - as.apply(req) - - resp, err := hc.Do(req) - - if err != nil { - return err - } else if resp.StatusCode >= 400 { + if resp.StatusCode >= http.StatusBadRequest { return fmt.Errorf("web hook failed with status code %v", resp.StatusCode) } diff --git a/selfservice/hook/web_hook_test.go b/selfservice/hook/web_hook_test.go deleted file mode 100644 index 41b0ba135f02..000000000000 --- a/selfservice/hook/web_hook_test.go +++ /dev/null @@ -1,268 +0,0 @@ -package hook - -import ( - _ "embed" - "encoding/base64" - "encoding/json" - "io" - "net/http" - "testing" - - "github.com/hashicorp/go-retryablehttp" - - "github.com/sirupsen/logrus/hooks/test" - - "github.com/ory/x/logrusx" - - "github.com/ory/kratos/identity" - "github.com/ory/kratos/x" - - "github.com/stretchr/testify/require" - - "github.com/ory/kratos/selfservice/flow/login" - - "github.com/stretchr/testify/assert" -) - -func TestNoopAuthStrategy(t *testing.T) { - req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} - auth := NoopAuthStrategy{} - - auth.apply(&req) - - assert.Empty(t, req.Header, "Empty auth strategy shall not modify any request headers") -} - -func TestBasicAuthStrategy(t *testing.T) { - req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} - auth := BasicAuthStrategy{ - user: "test-user", - password: "test-pass", - } - - auth.apply(&req) - - assert.Len(t, req.Header, 1) - - user, pass, _ := req.BasicAuth() - assert.Equal(t, "test-user", user) - assert.Equal(t, "test-pass", pass) -} - -func TestApiKeyInHeaderStrategy(t *testing.T) { - req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} - auth := ApiKeyStrategy{ - in: "header", - name: "my-api-key-name", - value: "my-api-key-value", - } - - auth.apply(&req) - - require.Len(t, req.Header, 1) - - actualValue := req.Header.Get("my-api-key-name") - assert.Equal(t, "my-api-key-value", actualValue) -} - -func TestApiKeyInCookieStrategy(t *testing.T) { - req := retryablehttp.Request{Request: &http.Request{Header: map[string][]string{}}} - auth := ApiKeyStrategy{ - in: "cookie", - name: "my-api-key-name", - value: "my-api-key-value", - } - - auth.apply(&req) - - cookies := req.Cookies() - assert.Len(t, cookies, 1) - - assert.Equal(t, "my-api-key-name", cookies[0].Name) - assert.Equal(t, "my-api-key-value", cookies[0].Value) -} - -//go:embed stub/test_body.jsonnet -var testBodyJSONNet []byte - -func TestJsonNetSupport(t *testing.T) { - f := &login.Flow{ID: x.NewUUID()} - i := identity.NewIdentity("") - l := logrusx.New("kratos", "test") - - for _, tc := range []struct { - desc, template string - data *templateContext - }{ - { - desc: "simple file URI", - template: "file://./stub/test_body.jsonnet", - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"c1=v1", "c2=v2"}, - "Some-Header": []string{"Some-Value"}, - }, - RequestMethod: "POST", - RequestUrl: "https://test.kratos.ory.sh/some-test-path", - Identity: i, - }, - }, - { - desc: "legacy filepath without scheme", - template: "./stub/test_body.jsonnet", - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"c1=v1", "c2=v2"}, - "Some-Header": []string{"Some-Value"}, - }, - RequestMethod: "POST", - RequestUrl: "https://test.kratos.ory.sh/some-test-path", - Identity: i, - }, - }, - { - desc: "base64 encoded template URI", - template: "base64://" + base64.StdEncoding.EncodeToString(testBodyJSONNet), - data: &templateContext{ - Flow: f, - RequestHeaders: http.Header{ - "Cookie": []string{"foo=bar"}, - "My-Custom-Header": []string{"Cumstom-Value"}, - }, - RequestMethod: "PUT", - RequestUrl: "https://test.kratos.ory.sh/other-test-path", - Identity: i, - }, - }, - } { - t.Run("case="+tc.desc, func(t *testing.T) { - b, err := createBody(l, tc.template, tc.data, retryablehttp.NewClient()) - require.NoError(t, err) - body, err := io.ReadAll(b) - require.NoError(t, err) - - expected, err := json.Marshal(map[string]interface{}{ - "flow_id": tc.data.Flow.GetID(), - "identity_id": tc.data.Identity.ID, - "headers": tc.data.RequestHeaders, - "method": tc.data.RequestMethod, - "url": tc.data.RequestUrl, - }) - require.NoError(t, err) - - assert.JSONEq(t, string(expected), string(body)) - }) - } - - t.Run("case=warns about legacy usage", func(t *testing.T) { - hook := test.Hook{} - l := logrusx.New("kratos", "test", logrusx.WithHook(&hook)) - - _, _ = createBody(l, "./foo", nil, retryablehttp.NewClient()) - - require.Len(t, hook.Entries, 1) - assert.Contains(t, hook.LastEntry().Message, "support for filepaths without a 'file://' scheme will be dropped") - }) - - t.Run("case=return non nil body reader on empty templateURI", func(t *testing.T) { - body, err := createBody(l, "", nil, retryablehttp.NewClient()) - assert.NotNil(t, body) - assert.Nil(t, err) - }) -} - -func TestWebHookConfig(t *testing.T) { - for _, tc := range []struct { - strategy string - method string - url string - body string - rawConfig string - authStrategy AuthStrategy - }{ - { - strategy: "empty", - method: "POST", - url: "https://test.kratos.ory.sh/my_hook1", - body: "/path/to/my/jsonnet1.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook1", - "method": "POST", - "body": "/path/to/my/jsonnet1.file" - }`, - authStrategy: &NoopAuthStrategy{}, - }, - { - strategy: "basic_auth", - method: "GET", - url: "https://test.kratos.ory.sh/my_hook2", - body: "/path/to/my/jsonnet2.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook2", - "method": "GET", - "body": "/path/to/my/jsonnet2.file", - "auth": { - "type": "basic_auth", - "config": { - "user": "test-api-user", - "password": "secret" - } - } - }`, - authStrategy: &BasicAuthStrategy{}, - }, - { - strategy: "api-key/header", - method: "DELETE", - url: "https://test.kratos.ory.sh/my_hook3", - body: "/path/to/my/jsonnet3.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook3", - "method": "DELETE", - "body": "/path/to/my/jsonnet3.file", - "auth": { - "type": "api_key", - "config": { - "in": "header", - "name": "my-api-key", - "value": "secret" - } - } - }`, - authStrategy: &ApiKeyStrategy{}, - }, - { - strategy: "api-key/cookie", - method: "POST", - url: "https://test.kratos.ory.sh/my_hook4", - body: "/path/to/my/jsonnet4.file", - rawConfig: `{ - "url": "https://test.kratos.ory.sh/my_hook4", - "method": "POST", - "body": "/path/to/my/jsonnet4.file", - "auth": { - "type": "api_key", - "config": { - "in": "cookie", - "name": "my-api-key", - "value": "secret" - } - } - }`, - authStrategy: &ApiKeyStrategy{}, - }, - } { - t.Run("auth-strategy="+tc.strategy, func(t *testing.T) { - conf, err := newWebHookConfig([]byte(tc.rawConfig)) - assert.Nil(t, err) - - assert.Equal(t, tc.url, conf.URL) - assert.Equal(t, tc.method, conf.Method) - assert.Equal(t, tc.body, conf.TemplateURI) - assert.NotNil(t, conf.Auth) - assert.IsTypef(t, tc.authStrategy, conf.Auth, "Auth should be of the expected type") - }) - } -} diff --git a/selfservice/strategy/link/sender.go b/selfservice/strategy/link/sender.go index f6567467f7f0..df4a9bf2b851 100644 --- a/selfservice/strategy/link/sender.go +++ b/selfservice/strategy/link/sender.go @@ -7,6 +7,8 @@ import ( "github.com/hashicorp/go-retryablehttp" + "github.com/ory/kratos/courier/template/email" + "github.com/ory/x/httpx" "github.com/pkg/errors" @@ -16,7 +18,6 @@ import ( "github.com/ory/x/urlx" "github.com/ory/kratos/courier" - templates "github.com/ory/kratos/courier/template" "github.com/ory/kratos/driver/config" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/flow/recovery" @@ -66,7 +67,7 @@ func (s *Sender) SendRecoveryLink(ctx context.Context, r *http.Request, f *recov address, err := s.r.IdentityPool().FindRecoveryAddressByValue(ctx, identity.RecoveryAddressTypeEmail, to) if err != nil { - if err := s.send(ctx, string(via), templates.NewRecoveryInvalid(s.r, &templates.RecoveryInvalidModel{To: to})); err != nil { + if err := s.send(ctx, string(via), email.NewRecoveryInvalid(s.r, &email.RecoveryInvalidModel{To: to})); err != nil { return err } return errors.Cause(ErrUnknownAddress) @@ -106,7 +107,7 @@ func (s *Sender) SendVerificationLink(ctx context.Context, f *verification.Flow, WithField("via", via). WithSensitiveField("email_address", address). Info("Sending out invalid verification email because address is unknown.") - if err := s.send(ctx, string(via), templates.NewVerificationInvalid(s.r, &templates.VerificationInvalidModel{To: to})); err != nil { + if err := s.send(ctx, string(via), email.NewVerificationInvalid(s.r, &email.VerificationInvalidModel{To: to})); err != nil { return err } return errors.Cause(ErrUnknownAddress) @@ -145,8 +146,8 @@ func (s *Sender) SendRecoveryTokenTo(ctx context.Context, f *recovery.Flow, i *i return err } - return s.send(ctx, string(address.Via), templates.NewRecoveryValid(s.r, - &templates.RecoveryValidModel{To: address.Value, RecoveryURL: urlx.CopyWithQuery( + return s.send(ctx, string(address.Via), email.NewRecoveryValid(s.r, + &email.RecoveryValidModel{To: address.Value, RecoveryURL: urlx.CopyWithQuery( urlx.AppendPaths(s.r.Config(ctx).SelfServiceLinkMethodBaseURL(), recovery.RouteSubmitFlow), url.Values{ "token": {token.Token}, @@ -168,8 +169,8 @@ func (s *Sender) SendVerificationTokenTo(ctx context.Context, f *verification.Fl return err } - if err := s.send(ctx, string(address.Via), templates.NewVerificationValid(s.r, - &templates.VerificationValidModel{To: address.Value, VerificationURL: urlx.CopyWithQuery( + if err := s.send(ctx, string(address.Via), email.NewVerificationValid(s.r, + &email.VerificationValidModel{To: address.Value, VerificationURL: urlx.CopyWithQuery( urlx.AppendPaths(s.r.Config(ctx).SelfServiceLinkMethodBaseURL(), verification.RouteSubmitFlow), url.Values{ "flow": {f.ID.String()}, diff --git a/test/schema/fixtures/config.schema.test.failure/root.SMSConfigmalformedURL.yaml b/test/schema/fixtures/config.schema.test.failure/root.SMSConfigmalformedURL.yaml new file mode 100644 index 000000000000..f5b1003cff6d --- /dev/null +++ b/test/schema/fixtures/config.schema.test.failure/root.SMSConfigmalformedURL.yaml @@ -0,0 +1,5 @@ +sms: + request_config: + url: "malformed uri" + method: POST + body: "malformed uri" diff --git a/test/schema/fixtures/config.schema.test.success/root.courierSMS.yaml b/test/schema/fixtures/config.schema.test.success/root.courierSMS.yaml new file mode 100644 index 000000000000..d8c9707ed891 --- /dev/null +++ b/test/schema/fixtures/config.schema.test.success/root.courierSMS.yaml @@ -0,0 +1,23 @@ +selfservice: + default_browser_return_url: "#/definitions/defaultReturnTo" + +dsn: foo + +identity: + schemas: + - id: default + url: https://example.com + +courier: + smtp: + connection_uri: smtps://foo:bar@my-mailserver:1234/ + from_address: no-reply@ory.kratos.sh + sms: + enabled: true + from: "+19592155527" + request_config: + url: https://sms.example.com + method: POST + body: file://request.config.twilio.jsonnet + header: + 'Content-Type': "application/x-www-form-urlencoded"