Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add custom sms hook #1474

Merged
merged 12 commits into from
Mar 27, 2024
Merged
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ require (
github.com/fatih/structs v1.1.0
github.com/gobuffalo/pop/v6 v6.1.1
github.com/jackc/pgx/v4 v4.18.2
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721
github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869
github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747
github.com/xeipuuv/gojsonschema v1.2.0
Expand Down Expand Up @@ -146,4 +147,6 @@ require (
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.21
go 1.21.0

toolchain go1.21.6
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUq
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc=
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
Expand Down
1 change: 1 addition & 0 deletions internal/api/errorcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,5 @@ const (
ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit"
ErrorBadCodeVerifier ErrorCode = "bad_code_verifier"
ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled"
ErrorHookTimeout ErrorCode = "hook_timeout"
)
4 changes: 4 additions & 0 deletions internal/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ func conflictError(fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...)
}

func gatewayTimeoutError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError {
J0 marked this conversation as resolved.
Show resolved Hide resolved
return httpError(http.StatusGatewayTimeout, errorCode, fmtString, args...)
}

// HTTPError is an error with a message and an HTTP status code.
type HTTPError struct {
HTTPStatus int `json:"code"` // do not rename the JSON tags!
Expand Down
189 changes: 182 additions & 7 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,35 @@
package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"strings"
"time"

"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/observability"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"

"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/hooks"

"github.com/supabase/auth/internal/storage"
)

const (
DefaultHTTPHookTimeout = 5 * time.Second
DefaultHTTPHookRetries = 3
HTTPHookBackoffDuration = 2 * time.Second
)

func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
db := a.db.WithContext(ctx)

Expand Down Expand Up @@ -55,20 +74,176 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name
return response, nil
}

// invokeHook invokes the hook code. tx can be nil, in which case a new
func readBodyWithLimit(rsp *http.Response) ([]byte, error) {
J0 marked this conversation as resolved.
Show resolved Hide resolved
defer rsp.Body.Close()

const limit = 20 * 1024 // 20KB
J0 marked this conversation as resolved.
Show resolved Hide resolved
limitedReader := io.LimitedReader{R: rsp.Body, N: limit}

body, err := io.ReadAll(&limitedReader)
if err != nil {
return nil, err
}

if limitedReader.N <= 0 {
// Attempt to read one more byte to check if we're exactly at the limit or over
_, err := rsp.Body.Read(make([]byte, 1))
if err == nil {
// If we could read more, then the payload was too large
return nil, fmt.Errorf("payload too large")
}
}

return body, nil
}

func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
client := http.Client{
Timeout: DefaultHTTPHookTimeout,
}
log := observability.GetLogEntry(r)
requestURL := hookConfig.URI
hookLog := log.WithFields(logrus.Fields{
"component": "auth_hook",
"url": requestURL,
})

inputPayload, err := json.Marshal(input)
if err != nil {
return nil, err
}
start := time.Now()
for i := 0; i < DefaultHTTPHookRetries; i++ {
hookLog.Infof("invocation attempt: %d", i)
J0 marked this conversation as resolved.
Show resolved Hide resolved
if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout {
return []byte{}, gatewayTimeoutError(ErrorHookTimeout, "failed to reach hook within timeout")
}
msgID := uuid.Must(uuid.NewV4())
currentTime := time.Now()
signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload)
if err != nil {
return nil, err
}

req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(inputPayload))
J0 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, internalServerError("Failed to make request object").WithInternalError(err)
J0 marked this conversation as resolved.
Show resolved Hide resolved
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("webhook-id", msgID.String())
req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))

watcher, req := watchForConnection(req)
rsp, err := client.Do(req)

if err != nil {
J0 marked this conversation as resolved.
Show resolved Hide resolved
if terr, ok := err.(net.Error); ok && terr.Timeout() {
hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
time.Sleep(HTTPHookBackoffDuration)
continue
} else if !watcher.gotConn && i < DefaultHTTPHookRetries-1 {
hookLog.Errorf("Failed to establish a connection on attempt %d with err %s", i, err)
time.Sleep(HTTPHookBackoffDuration)
continue
} else if i == DefaultHTTPHookRetries-1 {
return nil, gatewayTimeoutError(ErrorHookTimeout, "Failed to reach hook within allotted interval")

} else {
return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err)
}
}
J0 marked this conversation as resolved.
Show resolved Hide resolved

switch rsp.StatusCode {
case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
if rsp.Body == nil {
return nil, nil
}
body, err := readBodyWithLimit(rsp)
if err != nil {
return nil, err
}
return body, nil
case http.StatusTooManyRequests, http.StatusServiceUnavailable:
retryAfterHeader := rsp.Header.Get("retry-after")
// Check for truthy values to allow for flexibility to swtich to time duration
if retryAfterHeader != "" {
continue
}
return []byte{}, internalServerError("Service currently unavailable")
J0 marked this conversation as resolved.
Show resolved Hide resolved
case http.StatusBadRequest:
return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook")
J0 marked this conversation as resolved.
Show resolved Hide resolved
case http.StatusUnauthorized:
return []byte{}, httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "Hook requires authorizaition token")
default:
return []byte{}, internalServerError("Error executing Hook")
J0 marked this conversation as resolved.
Show resolved Hide resolved
}
}
return nil, internalServerError("error executing hook")
J0 marked this conversation as resolved.
Show resolved Hide resolved
}

func watchForConnection(req *http.Request) (*connectionWatcher, *http.Request) {
J0 marked this conversation as resolved.
Show resolved Hide resolved
w := new(connectionWatcher)
t := &httptrace.ClientTrace{
GotConn: w.GotConn,
}

req = req.WithContext(httptrace.WithClientTrace(req.Context(), t))
return w, req
}

type connectionWatcher struct {
gotConn bool
}

func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) {
c.gotConn = true
}

func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error {
switch input.(type) {
case *hooks.CustomSMSProviderInput:
hookOutput, ok := output.(*hooks.CustomSMSProviderOutput)
if !ok {
panic("output should be *hooks.CustomSMSProviderOutput")
}
var response []byte
var err error

if response, err = a.runHTTPHook(r, a.config.Hook.CustomSMSProvider, input, output); err != nil {
return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err)
}
if err != nil {
return err
}

if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err)
}
fmt.Printf("%v", hookOutput)
J0 marked this conversation as resolved.
Show resolved Hide resolved

default:
panic("unknown HTTP hook type")
}
return nil
}

// invokePostgresHook invokes the hook code. tx can be nil, in which case a new
// transaction is opened. If calling invokeHook within a transaction, always
// pass the current transaciton, as pool-exhaustion deadlocks are very easy to
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
// trigger.
func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, output any) error {
func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error {
config := a.config
// Switch based on hook type
switch input.(type) {
case *hooks.MFAVerificationAttemptInput:
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.MFAVerificationAttemptOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
}

Expand All @@ -94,7 +269,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
panic("output should be *hooks.PasswordVerificationAttemptOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking password verification hook.").WithInternalError(err)
}

Expand All @@ -120,7 +295,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
panic("output should be *hooks.CustomAccessTokenOutput")
}

if _, err := a.runPostgresHook(ctx, tx, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
return internalServerError("Error invoking access token hook.").WithInternalError(err)
}

Expand Down Expand Up @@ -155,6 +330,6 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
return nil

default:
panic("unknown hook input type")
panic("unknown Postgres hook input type")
}
}
Loading
Loading