diff --git a/internal/api/api.go b/internal/api/api.go index 8613a05dc..6e49e8a19 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -136,9 +136,14 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite) r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) { // rate limit per hour - limiter := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{ + limitAnonymousSignIns := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, }).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"}) + + limitSignups := tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + r.Post("/", func(w http.ResponseWriter, r *http.Request) error { params := &SignupParams{} if err := retrieveRequestParams(r, params); err != nil { @@ -148,19 +153,50 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne if !api.config.External.AnonymousUsers.Enabled { return unprocessableEntityError(ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled") } - if _, err := api.limitHandler(limiter)(w, r); err != nil { + if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil { return err } return api.SignupAnonymously(w, r) } + + // apply ip-based rate limiting on otps + if _, err := api.limitHandler(limitSignups)(w, r); err != nil { + return err + } + // apply shared rate limiting on email / phone + if _, err := sharedLimiter(w, r); err != nil { + return err + } return api.Signup(w, r) }) }) - r.With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) - r.With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend) - r.With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) + r.With(api.limitHandler( + // Allow requests at the specified rate per 5 minutes + tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30), + )).With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) - r.With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp) + r.With(api.limitHandler( + // Allow requests at the specified rate per 5 minutes + tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30), + )).With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend) + + r.With(api.limitHandler( + // Allow requests at the specified rate per 5 minutes + tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30), + )).With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) + + r.With(api.limitHandler( + // Allow requests at the specified rate per 5 minutes + tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30), + )).With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp) r.With(api.limitHandler( // Allow requests at the specified rate per 5 minutes. @@ -187,7 +223,12 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.With(api.requireAuthentication).Route("/user", func(r *router) { r.Get("/", api.UserGet) - r.With(sharedLimiter).Put("/", api.UserUpdate) + r.With(api.limitHandler( + // Allow requests at the specified rate per 5 minutes + tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30), + )).With(sharedLimiter).Put("/", api.UserUpdate) r.Route("/identities", func(r *router) { r.Use(api.requireManualLinkingEnabled) diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index a9d908c32..4d0e327f3 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/didip/tollbooth/v5" + "github.com/didip/tollbooth/v5/limiter" jwt "github.com/golang-jwt/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -356,3 +358,135 @@ func TestTimeoutResponseWriter(t *testing.T) { require.Equal(t, w1.Result(), w2.Result()) } + +func (ts *MiddlewareTestSuite) TestLimitHandler() { + ts.Config.RateLimitHeader = "X-Rate-Limit" + lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }) + + okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + b, _ := json.Marshal(map[string]interface{}{"message": "ok"}) + w.Write([]byte(b)) + }) + + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") + w := httptest.NewRecorder() + ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var data map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), "ok", data["message"]) + } + + // 6th request should fail and return a rate limit exceeded error + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") + w := httptest.NewRecorder() + ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) +} + +func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() { + // setup config for shared limiter and ip-based limiter to work + ts.Config.RateLimitHeader = "X-Rate-Limit" + ts.Config.External.Email.Enabled = true + ts.Config.External.Phone.Enabled = true + ts.Config.Mailer.Autoconfirm = false + ts.Config.Sms.Autoconfirm = false + + ipBasedLimiter := func(max float64) *limiter.Limiter { + return tollbooth.NewLimiter(max, &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }) + } + + okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + cases := []struct { + desc string + sharedLimiterConfig *conf.GlobalConfiguration + ipBasedLimiterConfig float64 + body map[string]interface{} + expectedErrorCode string + }{ + { + desc: "Exceed ip-based rate limit before shared limiter", + sharedLimiterConfig: &conf.GlobalConfiguration{ + RateLimitEmailSent: 10, + RateLimitSmsSent: 10, + }, + ipBasedLimiterConfig: 1, + body: map[string]interface{}{ + "email": "foo@example.com", + }, + expectedErrorCode: ErrorCodeOverRequestRateLimit, + }, + { + desc: "Exceed email shared limiter", + sharedLimiterConfig: &conf.GlobalConfiguration{ + RateLimitEmailSent: 1, + RateLimitSmsSent: 1, + }, + ipBasedLimiterConfig: 10, + body: map[string]interface{}{ + "email": "foo@example.com", + }, + expectedErrorCode: ErrorCodeOverEmailSendRateLimit, + }, + { + desc: "Exceed sms shared limiter", + sharedLimiterConfig: &conf.GlobalConfiguration{ + RateLimitEmailSent: 1, + RateLimitSmsSent: 1, + }, + ipBasedLimiterConfig: 10, + body: map[string]interface{}{ + "phone": "123456789", + }, + expectedErrorCode: ErrorCodeOverSMSSendRateLimit, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent + ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent + lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig)) + sharedLimiter := ts.API.limitEmailOrPhoneSentHandler() + + // get the minimum amount to reach the threshold just before the rate limit is exceeded + threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig) + for i := 0; i < int(threshold); i++ { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) + req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") + + w := httptest.NewRecorder() + lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) + req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") + + // check if the rate limit is exceeded with the expected error code + w := httptest.NewRecorder() + lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + var data map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), c.expectedErrorCode, data["error_code"]) + }) + } +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index d3ba720a0..35024ea52 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -223,6 +223,7 @@ type GlobalConfiguration struct { RateLimitTokenRefresh float64 `split_words:"true" default:"150"` RateLimitSso float64 `split_words:"true" default:"30"` RateLimitAnonymousUsers float64 `split_words:"true" default:"30"` + RateLimitOtp float64 `split_words:"true" default:"30"` SiteURL string `json:"site_url" split_words:"true" required:"true"` URIAllowList []string `json:"uri_allow_list" split_words:"true"`