diff --git a/internal/controlplane/handlers_githubwebhooks_test.go b/internal/controlplane/handlers_githubwebhooks_test.go index 4b3c90aa68..bfec3ffbdb 100644 --- a/internal/controlplane/handlers_githubwebhooks_test.go +++ b/internal/controlplane/handlers_githubwebhooks_test.go @@ -21,8 +21,10 @@ import ( "database/sql" "encoding/json" "fmt" + "io" "net/http" "os" + "strings" "testing" "time" @@ -436,6 +438,73 @@ func (s *UnitTestSuite) TestNoopWebhookHandler() { assert.Equal(t, http.StatusOK, resp.StatusCode, "unexpected status code") } +func (s *UnitTestSuite) TestHandleWebHookWithTooLargeRequest() { + t := s.T() + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := mockdb.NewMockStore(ctrl) + srv, evt := newDefaultServer(t, mockStore, nil) + defer evt.Close() + + pq := testqueue.NewPassthroughQueue(t) + queued := pq.GetQueue() + + evt.Register(events.TopicQueueEntityEvaluate, pq.Pass) + + go func() { + err := evt.Run(context.Background()) + require.NoError(t, err, "failed to run eventer") + }() + + <-evt.Running() + + hook := withMaxSizeMiddleware(srv.HandleGitHubWebHook()) + port, err := rand.GetRandomPort() + if err != nil { + t.Fatal(err) + } + addr := fmt.Sprintf("localhost:%d", port) + server := &http.Server{ + Addr: fmt.Sprintf(":%d", port), + Handler: hook, + ReadHeaderTimeout: 1 * time.Second, + } + go server.ListenAndServe() + + event := github.PackageEvent{ + Action: github.String("published"), + Repo: &github.Repository{ + ID: github.Int64(12345), + Name: github.String("stacklok/minder"), + }, + Org: &github.Organization{ + Login: github.String("stacklok"), + }, + } + packageJson, err := json.Marshal(event) + require.NoError(t, err, "failed to marshal package event") + + maliciousBody := strings.NewReader(strings.Repeat("1337", 1000000000)) + maliciousBodyReader := io.MultiReader(maliciousBody, maliciousBody, maliciousBody, maliciousBody, maliciousBody) + _ = packageJson + + client := &http.Client{} + req, err := http.NewRequest("POST", fmt.Sprintf("http://%s", addr), maliciousBodyReader) + require.NoError(t, err, "failed to create request") + + req.Header.Add("X-GitHub-Event", "meta") + req.Header.Add("X-GitHub-Delivery", "12345") + req.Header.Add("Content-Type", "application/json") + resp, err := httpDoWithRetry(client, req) + require.NoError(t, err, "failed to make request") + // We expect OK since we don't want to leak information about registered repositories + require.Equal(t, http.StatusBadRequest, resp.StatusCode, "unexpected status code") + assert.Len(t, queued, 0) +} + func TestAll(t *testing.T) { t.Parallel() diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 1444905e0c..3c4fc31f4d 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -70,6 +70,10 @@ const metricsPath = "/metrics" var ( readHeaderTimeout = 2 * time.Second + + // RequestBodyMaxBytes is the maximum number of bytes that can be read from a request body + // We limit to 2MB for now + RequestBodyMaxBytes int64 = 2 << 20 ) // Server represents the controlplane server @@ -316,10 +320,10 @@ func (s *Server) StartHTTPServer(ctx context.Context) error { return fmt.Errorf("failed to register GitHub App callback handler: %w", err) } - mux.Handle("/", s.handlerWithHTTPMiddleware(gwmux)) - mux.Handle("/api/v1/webhook/", mw(s.HandleGitHubWebHook())) - mux.Handle("/api/v1/ghapp/", mw(s.HandleGitHubAppWebhook())) - mux.Handle("/api/v1/gh-marketplace/", mw(s.NoopWebhookHandler())) + mux.Handle("/", withMaxSizeMiddleware(s.handlerWithHTTPMiddleware(gwmux))) + mux.Handle("/api/v1/webhook/", mw(withMaxSizeMiddleware(s.HandleGitHubWebHook()))) + mux.Handle("/api/v1/ghapp/", mw(withMaxSizeMiddleware(s.HandleGitHubAppWebhook()))) + mux.Handle("/api/v1/gh-marketplace/", mw(withMaxSizeMiddleware(s.NoopWebhookHandler()))) mux.Handle("/static/", fs) errch := make(chan error) @@ -451,3 +455,10 @@ func shutdownHandler(component string, sdf shutdowner) { log.Fatal().Msgf("error shutting down '%s': %+v", component, err) } } + +func withMaxSizeMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, RequestBodyMaxBytes) + h.ServeHTTP(w, r) + }) +}