Skip to content

Commit

Permalink
fix: modify read checks
Browse files Browse the repository at this point in the history
  • Loading branch information
J0 committed Mar 17, 2024
1 parent 97884cc commit e885fe3
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 35 deletions.
61 changes: 38 additions & 23 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"time"

"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/observability"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"

Expand All @@ -24,8 +26,9 @@ import (
)

const (
DefaultHTTPHookTimeout = 5 * time.Second
DefaultHTTPHookRetries = 3
DefaultHTTPHookTimeout = 5 * time.Second
DefaultHTTPHookRetries = 3
HTTPHookBackoffDuration = 2 * time.Second
)

func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
Expand Down Expand Up @@ -72,13 +75,37 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name
return response, nil
}

func (a *API) runHTTPHook(hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
func readBodyWithLimit(rsp *http.Response) ([]byte, error) {
defer rsp.Body.Close()

const limit = 20 * 1024 // 20KB
limitedReader := io.LimitedReader{R: rsp.Body, N: limit}

body, err := io.ReadAll(&limitedReader)
if err != nil {
return nil, err
}

if limitedReader.N <= 0 {
// Attempt to read one more byte to check if we're exactly at the limit or over
_, err := rsp.Body.Read(make([]byte, 1))
if err == nil {
// If we could read more, then the payload was too large
return nil, fmt.Errorf("payload too large")
}
}

return body, nil
}

func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {

client := http.Client{
Timeout: DefaultHTTPHookTimeout,
}
log := observability.GetLogEntry(r)
requestURL := hookConfig.URI
hookLog := logrus.WithFields(logrus.Fields{
hookLog := log.WithFields(logrus.Fields{
"component": "auth_hook",
"url": requestURL,
})
Expand All @@ -90,7 +117,7 @@ func (a *API) runHTTPHook(hookConfig conf.ExtensibilityPointConfiguration, input
start := time.Now()
for i := 0; i < DefaultHTTPHookRetries; i++ {
hookLog.Infof("invocation attempt: %d", i)
if time.Since(start) > time.Duration(i)*DefaultHTTPHookTimeout {
if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout {
return []byte{}, gatewayTimeoutError(ErrorHookGatewayTimeout, "failed to reach hook within timeout")
}
msgID := uuid.Must(uuid.NewV4())
Expand All @@ -106,7 +133,6 @@ func (a *API) runHTTPHook(hookConfig conf.ExtensibilityPointConfiguration, input
}

req.Header.Set("Content-Type", "application/json")

req.Header.Set("webhook-id", msgID.String())
req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))
Expand All @@ -117,12 +143,11 @@ func (a *API) runHTTPHook(hookConfig conf.ExtensibilityPointConfiguration, input
if err != nil {
if terr, ok := err.(net.Error); ok && terr.Timeout() {
hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
// TODO: workshop the sleep time value
time.Sleep(DefaultHTTPHookTimeout)
time.Sleep(HTTPHookBackoffDuration)
continue
} else if !watcher.gotConn && i < DefaultHTTPHookRetries-1 {
hookLog.Errorf("Failed to establish a connection on attempt %d with err %s", i, err)
time.Sleep(DefaultHTTPHookTimeout)
time.Sleep(HTTPHookBackoffDuration)
continue
} else if i == DefaultHTTPHookRetries-1 {
return nil, gatewayTimeoutError(ErrorHookGatewayTimeout, "failed to reach hook within allotted interval")
Expand All @@ -137,18 +162,14 @@ func (a *API) runHTTPHook(hookConfig conf.ExtensibilityPointConfiguration, input
if rsp.Body == nil {
return nil, nil
}
defer rsp.Body.Close()
body, err := io.ReadAll(rsp.Body)
body, err := readBodyWithLimit(rsp)
if err != nil {
return nil, err
}
if isOverSizeLimit(body) {
return nil, internalServerError("payload too large")
}
return body, nil
case http.StatusTooManyRequests, http.StatusServiceUnavailable:
retryAfterHeader := rsp.Header.Get("retry-after")
if retryAfterHeader != "true" {
if retryAfterHeader != "" {
continue
}
return []byte{}, internalServerError("Service unavailable")
Expand Down Expand Up @@ -181,12 +202,6 @@ func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) {
c.gotConn = true
}

func isOverSizeLimit(payload []byte) bool {
// As per spec, payloads should be under 20KB https://github.com/standard-webhooks/standard-webhooks/blob/main/spec/standard-webhooks.md#payload-size
const maxSizeKB = 20 * 1024
return len(payload) > maxSizeKB
}

func validateHTTPHook(uri string) error {
u, err := url.Parse(uri)
if err != nil {
Expand All @@ -209,7 +224,7 @@ func validatePostgresHook(uri string) error {
return nil
}

func (a *API) invokeHTTPHook(input, output any, hookURI string) error {
func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error {
if err := validateHTTPHook(hookURI); err != nil {
return err
}
Expand All @@ -222,7 +237,7 @@ func (a *API) invokeHTTPHook(input, output any, hookURI string) error {
var response []byte
var err error

if response, err = a.runHTTPHook(a.config.Hook.CustomSMSProvider, input, output); err != nil {
if response, err = a.runHTTPHook(r, a.config.Hook.CustomSMSProvider, input, output); err != nil {
return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err)
}
if err != nil {
Expand Down
10 changes: 7 additions & 3 deletions internal/api/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ func (ts *HooksTestSuite) TestRunHTTPHook() {
},
{
description: "Too many requests with retry-after header",
status: http.StatusTooManyRequests,
status: http.StatusGatewayTimeout,
matchHeader: map[string]string{"retry-after": "true"},
expectError: true,
},
{
// TODO: maybe properly check for a condition in the w/o retry case
description: "Too many requests without retry header should not retry",
status: http.StatusTooManyRequests,
status: http.StatusGatewayTimeout,
expectError: true,
},
}
Expand All @@ -104,7 +104,11 @@ func (ts *HooksTestSuite) TestRunHTTPHook() {

var output hooks.CustomSMSProviderOutput

body, err := ts.API.runHTTPHook(ts.Config.Hook.CustomSMSProvider, &input, &output)
// Mock of original HTTP Rquest which triggered the hook
req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil)
require.NoError(ts.T(), err)

body, err := ts.API.runHTTPHook(req, ts.Config.Hook.CustomSMSProvider, &input, &output)
if err == nil && body != nil {
err = json.Unmarshal(body, &output)
require.NoError(ts.T(), err, "Unmarshal should not fail")
Expand Down
2 changes: 1 addition & 1 deletion internal/api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return internalServerError("Unable to get SMS provider").WithInternalError(err)
}
mID, serr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel)
mID, serr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel)
if serr != nil {
return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr)
}
Expand Down
5 changes: 3 additions & 2 deletions internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"bytes"
"github.com/supabase/auth/internal/hooks"
"net/http"
"regexp"
"strings"
"text/template"
Expand Down Expand Up @@ -41,7 +42,7 @@ func formatPhoneNumber(phone string) string {
}

// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) {
func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) {
config := a.config

var token *string
Expand Down Expand Up @@ -99,7 +100,7 @@ func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, p
OTP: otp,
}
output := hooks.CustomSMSProviderOutput{}
err := a.invokeHTTPHook(&input, &output, config.Hook.CustomSMSProvider.URI)
err := a.invokeHTTPHook(r, &input, &output, config.Hook.CustomSMSProvider.URI)
if err != nil {
return "", err
}
Expand Down
4 changes: 3 additions & 1 deletion internal/api/phone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ func (ts *PhoneTestSuite) TestFormatPhoneNumber() {

func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) {
u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud)
req, _ := http.NewRequest("POST", "http://localhost:9998/otp", nil)

require.NoError(ts.T(), err)
cases := []struct {
desc string
Expand Down Expand Up @@ -111,7 +113,7 @@ func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) {
ts.Run(c.desc, func() {
provider := &TestSmsProvider{}

_, err = ts.API.sendPhoneConfirmation(ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider)
_, err = ts.API.sendPhoneConfirmation(req, ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider)
require.Equal(ts.T(), c.expected, err)
u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
Expand Down
2 changes: 1 addition & 1 deletion internal/api/reauthenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return internalServerError("Failed to get SMS provider").WithInternalError(terr)
}
mID, err := a.sendPhoneConfirmation(tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider)
mID, err := a.sendPhoneConfirmation(r, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return terr
}
mID, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider)
mID, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider)
if terr != nil {
return terr
}
Expand All @@ -146,7 +146,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return terr
}
mID, terr := a.sendPhoneConfirmation(tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider)
mID, terr := a.sendPhoneConfirmation(r, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider)
if terr != nil {
return terr
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return internalServerError("Unable to get SMS provider").WithInternalError(terr)
}
if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil {
if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil {
return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr)
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
if terr != nil {
return internalServerError("Error finding SMS provider").WithInternalError(terr)
}
if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil {
if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil {
return internalServerError("Error sending phone change otp").WithInternalError(terr)
}
}
Expand Down

0 comments on commit e885fe3

Please sign in to comment.