-
Notifications
You must be signed in to change notification settings - Fork 378
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## What kind of change does this PR introduce? * Adds ip-based rate limiting on all endpoints that send OTPs either through email or phone with the config `GOTRUE_RATE_LIMIT_OTP` * IP-based rate limiting should always come before the shared limiter, so as to prevent the quota of the shared limiter from being consumed too quickly by the same ip-address
- Loading branch information
1 parent
f5c6fcd
commit 06464c0
Showing
3 changed files
with
183 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,8 @@ import ( | |
"testing" | ||
"time" | ||
|
||
"github.com/didip/tollbooth/v5" | ||
"github.com/didip/tollbooth/v5/limiter" | ||
jwt "github.com/golang-jwt/jwt" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
|
@@ -356,3 +358,135 @@ func TestTimeoutResponseWriter(t *testing.T) { | |
|
||
require.Equal(t, w1.Result(), w2.Result()) | ||
} | ||
|
||
func (ts *MiddlewareTestSuite) TestLimitHandler() { | ||
ts.Config.RateLimitHeader = "X-Rate-Limit" | ||
lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{ | ||
DefaultExpirationTTL: time.Hour, | ||
}) | ||
|
||
okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
b, _ := json.Marshal(map[string]interface{}{"message": "ok"}) | ||
w.Write([]byte(b)) | ||
}) | ||
|
||
for i := 0; i < 5; i++ { | ||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) | ||
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") | ||
w := httptest.NewRecorder() | ||
ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) | ||
require.Equal(ts.T(), http.StatusOK, w.Code) | ||
|
||
var data map[string]interface{} | ||
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) | ||
require.Equal(ts.T(), "ok", data["message"]) | ||
} | ||
|
||
// 6th request should fail and return a rate limit exceeded error | ||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) | ||
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") | ||
w := httptest.NewRecorder() | ||
ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) | ||
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) | ||
} | ||
|
||
func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() { | ||
// setup config for shared limiter and ip-based limiter to work | ||
ts.Config.RateLimitHeader = "X-Rate-Limit" | ||
ts.Config.External.Email.Enabled = true | ||
ts.Config.External.Phone.Enabled = true | ||
ts.Config.Mailer.Autoconfirm = false | ||
ts.Config.Sms.Autoconfirm = false | ||
|
||
ipBasedLimiter := func(max float64) *limiter.Limiter { | ||
return tollbooth.NewLimiter(max, &limiter.ExpirableOptions{ | ||
DefaultExpirationTTL: time.Hour, | ||
}) | ||
} | ||
|
||
okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
cases := []struct { | ||
desc string | ||
sharedLimiterConfig *conf.GlobalConfiguration | ||
ipBasedLimiterConfig float64 | ||
body map[string]interface{} | ||
expectedErrorCode string | ||
}{ | ||
{ | ||
desc: "Exceed ip-based rate limit before shared limiter", | ||
sharedLimiterConfig: &conf.GlobalConfiguration{ | ||
RateLimitEmailSent: 10, | ||
RateLimitSmsSent: 10, | ||
}, | ||
ipBasedLimiterConfig: 1, | ||
body: map[string]interface{}{ | ||
"email": "[email protected]", | ||
}, | ||
expectedErrorCode: ErrorCodeOverRequestRateLimit, | ||
}, | ||
{ | ||
desc: "Exceed email shared limiter", | ||
sharedLimiterConfig: &conf.GlobalConfiguration{ | ||
RateLimitEmailSent: 1, | ||
RateLimitSmsSent: 1, | ||
}, | ||
ipBasedLimiterConfig: 10, | ||
body: map[string]interface{}{ | ||
"email": "[email protected]", | ||
}, | ||
expectedErrorCode: ErrorCodeOverEmailSendRateLimit, | ||
}, | ||
{ | ||
desc: "Exceed sms shared limiter", | ||
sharedLimiterConfig: &conf.GlobalConfiguration{ | ||
RateLimitEmailSent: 1, | ||
RateLimitSmsSent: 1, | ||
}, | ||
ipBasedLimiterConfig: 10, | ||
body: map[string]interface{}{ | ||
"phone": "123456789", | ||
}, | ||
expectedErrorCode: ErrorCodeOverSMSSendRateLimit, | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
ts.Run(c.desc, func() { | ||
ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent | ||
ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent | ||
lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig)) | ||
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler() | ||
|
||
// get the minimum amount to reach the threshold just before the rate limit is exceeded | ||
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig) | ||
for i := 0; i < int(threshold); i++ { | ||
var buffer bytes.Buffer | ||
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) | ||
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) | ||
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") | ||
|
||
w := httptest.NewRecorder() | ||
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req) | ||
require.Equal(ts.T(), http.StatusOK, w.Code) | ||
} | ||
|
||
var buffer bytes.Buffer | ||
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) | ||
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) | ||
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") | ||
|
||
// check if the rate limit is exceeded with the expected error code | ||
w := httptest.NewRecorder() | ||
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req) | ||
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) | ||
|
||
var data map[string]interface{} | ||
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) | ||
require.Equal(ts.T(), c.expectedErrorCode, data["error_code"]) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters