Skip to content

Commit

Permalink
feat: add custom sms hook (supabase#1474)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Allows devs to use a Custom SMS Provider via a HTTP Hook. SQL Hooks are
not supported at this time as it requires significantly more effort to
invoke a Hook via `pg_net` and handle errors as compared to using
HTTP/HTTPS.

HTTP Hook invocation is done via the
[standardwebhooks](www.standardwebhooks.com) library for symmetric
hooks. Asymmetric hooks are being finalized by the committee and support
will be added shortly.


The following changes are being made from the internal RFC:
- [x] Hooks have a timeout of 5 seconds instead of 15 seconds so as not
to run for too long
- [x] Hooks have a size limit of 20kb. This is not stated in internal
RFC but is part of the recommendations under the standard webhooks RFC.
- [x] We allow hooks using the `http` protocol . This is to support
local development. Restriction to https can be done on the dashboard
page for connecting a hook.
- [x] Add log statements where relevant and write proper error messages
- [x] Add more Gock tests

---------

Co-authored-by: Stojan Dimitrovski <[email protected]>
  • Loading branch information
J0 and hf authored Mar 27, 2024
1 parent 2843dbe commit 0f6b29a
Show file tree
Hide file tree
Showing 18 changed files with 412 additions and 31 deletions.
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ require (
github.com/fatih/structs v1.1.0
github.com/gobuffalo/pop/v6 v6.1.1
github.com/jackc/pgx/v4 v4.18.2
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721
github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869
github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747
github.com/xeipuuv/gojsonschema v1.2.0
Expand Down Expand Up @@ -146,4 +147,6 @@ require (
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.21
go 1.21.0

toolchain go1.21.6
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUq
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc=
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
Expand Down
4 changes: 4 additions & 0 deletions internal/api/errorcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,8 @@ const (
ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit"
ErrorBadCodeVerifier ErrorCode = "bad_code_verifier"
ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled"
ErrorCodeHookTimeout ErrorCode = "hook_timeout"
ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry"
ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit"
ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size"
)
2 changes: 1 addition & 1 deletion internal/api/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestIsValidCodeChallenge(t *testing.T) {
}
}

func TestIsValidPKCEParmas(t *testing.T) {
func TestIsValidPKCEParams(t *testing.T) {
cases := []struct {
challengeMethod string
challenge string
Expand Down
159 changes: 152 additions & 7 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
package api

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"

"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/observability"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"

"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/hooks"

"github.com/supabase/auth/internal/storage"
)

const (
DefaultHTTPHookTimeout = 5 * time.Second
DefaultHTTPHookRetries = 3
HTTPHookBackoffDuration = 2 * time.Second
PayloadLimit = 200 * 1024 // 200KB
)

func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
db := a.db.WithContext(ctx)

Expand Down Expand Up @@ -55,20 +75,145 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name
return response, nil
}

// invokeHook invokes the hook code. tx can be nil, in which case a new
func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
client := http.Client{
Timeout: DefaultHTTPHookTimeout,
}
ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout)
defer cancel()

log := observability.GetLogEntry(r)
requestURL := hookConfig.URI
hookLog := log.WithFields(logrus.Fields{
"component": "auth_hook",
"url": requestURL,
})

inputPayload, err := json.Marshal(input)
if err != nil {
return nil, err
}
for i := 0; i < DefaultHTTPHookRetries; i++ {
if i == 0 {
hookLog.Debugf("invocation attempt: %d", i)
} else {
hookLog.Infof("invocation attempt: %d", i)
}
msgID := uuid.Must(uuid.NewV4())
currentTime := time.Now()
signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload))
if err != nil {
panic("Failed to make request object")
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("webhook-id", msgID.String())
req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))
// By default, Go Client sets encoding to gzip, which does not carry a content length header.
req.Header.Set("Accept-Encoding", "identity")

rsp, err := client.Do(req)
if err != nil && errors.Is(err, context.DeadlineExceeded) {
return nil, unprocessableEntityError(ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds()))

} else if err != nil {
if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 {
hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
time.Sleep(HTTPHookBackoffDuration)
continue
} else if i == DefaultHTTPHookRetries-1 {
return nil, unprocessableEntityError(ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries")
} else {
return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err)
}
}

defer rsp.Body.Close()

switch rsp.StatusCode {
case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
if rsp.Body == nil {
return nil, nil
}
contentLength := rsp.ContentLength
if contentLength == -1 {
return nil, unprocessableEntityError(ErrorCodeHookPayloadUnknownSize, "Payload size not known")
}
if contentLength >= PayloadLimit {
return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size is: %d bytes exceeded size limit of %d bytes", contentLength, PayloadLimit))
}
limitedReader := io.LimitedReader{R: rsp.Body, N: contentLength}
body, err := io.ReadAll(&limitedReader)
if err != nil {
return nil, err
}
return body, nil
case http.StatusTooManyRequests, http.StatusServiceUnavailable:
retryAfterHeader := rsp.Header.Get("retry-after")
// Check for truthy values to allow for flexibility to switch to time duration
if retryAfterHeader != "" {
continue
}
return nil, internalServerError("Service currently unavailable due to hook")
case http.StatusBadRequest:
return nil, internalServerError("Invalid payload sent to hook")
case http.StatusUnauthorized:
return nil, internalServerError("Hook requires authorization token")
default:
return nil, internalServerError("Error executing Hook")
}
}
return nil, nil
}

func (a *API) invokeHTTPHook(ctx context.Context, r *http.Request, input, output any, hookURI string) error {
switch input.(type) {
case *hooks.CustomSMSProviderInput:
hookOutput, ok := output.(*hooks.CustomSMSProviderOutput)
if !ok {
panic("output should be *hooks.CustomSMSProviderOutput")
}
var response []byte
var err error

if response, err = a.runHTTPHook(ctx, r, a.config.Hook.CustomSMSProvider, input, output); err != nil {
return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err)
}
if err != nil {
return err
}

if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err)
}

default:
panic("unknown HTTP hook type")
}
return nil
}

// invokePostgresHook invokes the hook code. tx can be nil, in which case a new
// transaction is opened. If calling invokeHook within a transaction, always
// pass the current transaciton, as pool-exhaustion deadlocks are very easy to
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
// trigger.
func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, output any) error {
func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error {
config := a.config
// Switch based on hook type
switch input.(type) {
case *hooks.MFAVerificationAttemptInput:
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.MFAVerificationAttemptOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
}

Expand All @@ -94,7 +239,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
panic("output should be *hooks.PasswordVerificationAttemptOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking password verification hook.").WithInternalError(err)
}

Expand All @@ -120,7 +265,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
panic("output should be *hooks.CustomAccessTokenOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
return internalServerError("Error invoking access token hook.").WithInternalError(err)
}

Expand Down Expand Up @@ -155,6 +300,6 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
return nil

default:
panic("unknown hook input type")
panic("unknown Postgres hook input type")
}
}
Loading

0 comments on commit 0f6b29a

Please sign in to comment.