From e74f3b8af0ea977363f52c095bcc960a9efd29b4 Mon Sep 17 00:00:00 2001
From: Christian Simon
Date: Wed, 19 Jun 2024 13:48:55 +0100
Subject: [PATCH] Fix frontend header handling (#3363)

The query-frontend accidentally leaked content-type and content-encoding
information from the underlying connectgrpc implementation, which is used to
schedule queries onto the workers (queriers). This change filters those
headers out, so they are no longer passed on through the frontend.

This could be observed as errors like this:

```
$ profilecli query series
level=info msg="query series from querier" url=http://localhost:8080 from=2024-06-18T15:46:19.932471+01:00 to=2024-06-18T16:46:19.932472+01:00 labelNames=[]
error: failed to query: unknown: invalid content-type: "application/grpc,application/proto"; expecting "application/grpc"
exit status 1
```

These errors were only produced because Envoy (involved in our GC ingress) converted headers from this:

```
content-type: application/grpc
content-type: application/proto
```

to

```
content-type: application/grpc,application/proto
```

This then triggered a validation recently added to connect-go:
https://github.com/connectrpc/connect-go/pull/679
---
 pkg/frontend/frontend_test.go            | 162 +++++++++++++++++++++--
 pkg/util/connectgrpc/connectgrpc.go      |  18 +++
 pkg/util/connectgrpc/connectgrpc_test.go |  27 +++-
 pkg/validation/testutil.go               |   6 +
 4 files changed, 202 insertions(+), 11 deletions(-)

diff --git a/pkg/frontend/frontend_test.go b/pkg/frontend/frontend_test.go index a9f1b3e75e..7d26c30777 100644 --- a/pkg/frontend/frontend_test.go +++ b/pkg/frontend/frontend_test.go @@ -6,6 +6,7 @@ package frontend import ( + "bytes" "context" "fmt" "net/http" @@ -26,6 +27,7 @@ import ( "github.com/grafana/dskit/services" "github.com/grafana/dskit/test" "github.com/grafana/dskit/user" + "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" @@ -34,12 +36,18 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "github.com/grafana/pyroscope/api/gen/proto/go/querier/v1/querierv1connect" + typesv1 "github.com/grafana/pyroscope/api/gen/proto/go/types/v1" + connectapi "github.com/grafana/pyroscope/pkg/api/connect" "github.com/grafana/pyroscope/pkg/frontend/frontendpb" "github.com/grafana/pyroscope/pkg/frontend/frontendpbconnect" "github.com/grafana/pyroscope/pkg/querier/stats" + "github.com/grafana/pyroscope/pkg/querier/worker" + "github.com/grafana/pyroscope/pkg/scheduler" "github.com/grafana/pyroscope/pkg/scheduler/schedulerdiscovery" "github.com/grafana/pyroscope/pkg/scheduler/schedulerpb" "github.com/grafana/pyroscope/pkg/scheduler/schedulerpb/schedulerpbconnect" + "github.com/grafana/pyroscope/pkg/util/connectgrpc" "github.com/grafana/pyroscope/pkg/util/httpgrpc" "github.com/grafana/pyroscope/pkg/util/servicediscovery" "github.com/grafana/pyroscope/pkg/validation" @@ -51,14 +59,8 @@ func setupFrontend(t *testing.T, reg prometheus.Registerer, schedulerReplyFunc f return setupFrontendWithConcurrencyAndServerOptions(t, reg, schedulerReplyFunc, testFrontendWorkerConcurrency) } -func setupFrontendWithConcurrencyAndServerOptions(t *testing.T, reg prometheus.Registerer, schedulerReplyFunc func(f *Frontend, msg *schedulerpb.FrontendToScheduler)
*schedulerpb.SchedulerToFrontend, concurrency int) (*Frontend, *mockScheduler) { - s := httptest.NewUnstartedServer(nil) - mux := mux.NewRouter() - s.Config.Handler = h2c.NewHandler(mux, &http2.Server{}) - - s.Start() - - u, err := url.Parse(s.URL) +func cfgFromURL(t *testing.T, urlS string) Config { + u, err := url.Parse(urlS) require.NoError(t, err) port, err := strconv.Atoi(u.Port()) @@ -67,9 +69,20 @@ func setupFrontendWithConcurrencyAndServerOptions(t *testing.T, reg prometheus.R cfg := Config{} flagext.DefaultValues(&cfg) cfg.SchedulerAddress = u.Hostname() + ":" + u.Port() - cfg.WorkerConcurrency = concurrency cfg.Addr = u.Hostname() cfg.Port = port + return cfg +} + +func setupFrontendWithConcurrencyAndServerOptions(t *testing.T, reg prometheus.Registerer, schedulerReplyFunc func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend, concurrency int) (*Frontend, *mockScheduler) { + s := httptest.NewUnstartedServer(nil) + mux := mux.NewRouter() + s.Config.Handler = h2c.NewHandler(mux, &http2.Server{}) + + s.Start() + + cfg := cfgFromURL(t, s.URL) + cfg.WorkerConcurrency = concurrency logger := log.NewLogfmtLogger(os.Stdout) f, err := NewFrontend(cfg, validation.MockLimits{MaxQueryParallelismValue: 1}, logger, reg) @@ -181,6 +194,137 @@ func TestFrontendRequestsPerWorkerMetric(t *testing.T) { require.NoError(t, testutil.GatherAndCompare(reg, strings.NewReader(expectedMetrics), "pyroscope_query_frontend_workers_enqueued_requests_total")) } +func newFakeQuerierGRPCHandler() connectgrpc.GRPCHandler { + q := &fakeQuerier{} + mux := http.NewServeMux() + mux.Handle(querierv1connect.NewQuerierServiceHandler(q, connectapi.DefaultHandlerOptions()...)) + return connectgrpc.NewHandler(mux) +} + +type fakeQuerier struct { + querierv1connect.QuerierServiceHandler +} + +func (f *fakeQuerier) LabelNames(ctx context.Context, req *connect.Request[typesv1.LabelNamesRequest]) (*connect.Response[typesv1.LabelNamesResponse], error) { + return connect.NewResponse(&typesv1.LabelNamesResponse{ + Names: []string{"i", "have", "labels"}, + }), nil +} + +func headerToSlice(t testing.TB, header http.Header) []string { + buf := new(bytes.Buffer) + excludeHeaders := map[string]bool{"Content-Length": true, "Date": true} + require.NoError(t, header.WriteSubset(buf, excludeHeaders)) + sl := strings.Split(strings.ReplaceAll(buf.String(), "\r\n", "\n"), "\n") + if len(sl) > 0 && sl[len(sl)-1] == "" { + sl = sl[:len(sl)-1] + } + return sl +} + +// TestFrontendFullRoundtrip tests the full roundtrip of a request from the frontend to a fake querier and back, using an actual scheduler.
+func TestFrontendFullRoundtrip(t *testing.T) { + var ( + logger = log.NewNopLogger() + reg = prometheus.NewRegistry() + tenant = "tenant-a" + ) + if testing.Verbose() { + logger = log.NewLogfmtLogger(os.Stderr) + } + + // create server for frontend and scheduler + mux := mux.NewRouter() + // inject a span/tenant into the context + mux.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := user.InjectOrgID(r.Context(), tenant) + _, ctx = opentracing.StartSpanFromContext(ctx, "test") + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) + s := httptest.NewServer(h2c.NewHandler(mux, &http2.Server{})) + defer s.Close() + + // initialize the scheduler + schedCfg := scheduler.Config{} + flagext.DefaultValues(&schedCfg) + sched, err := scheduler.NewScheduler(schedCfg, validation.MockLimits{}, logger, reg) + require.NoError(t, err) + schedulerpbconnect.RegisterSchedulerForFrontendHandler(mux, sched) + schedulerpbconnect.RegisterSchedulerForQuerierHandler(mux, sched) + + // initialize the frontend + fCfg := cfgFromURL(t, s.URL) + f, err := NewFrontend(fCfg, validation.MockLimits{MaxQueryParallelismValue: 1}, logger, reg) + require.NoError(t, err) + frontendpbconnect.RegisterFrontendForQuerierHandler(mux, f) // probably not needed + querierv1connect.RegisterQuerierServiceHandler(mux, f) + + // create a querier worker + qWorkerCfg := worker.Config{} + flagext.DefaultValues(&qWorkerCfg) + qWorkerCfg.SchedulerAddress = fCfg.SchedulerAddress + qWorker, err := worker.NewQuerierWorker(qWorkerCfg, newFakeQuerierGRPCHandler(), log.NewLogfmtLogger(os.Stderr), prometheus.NewRegistry()) + require.NoError(t, err) + + // start services + svc, err := services.NewManager(sched, f, qWorker) + require.NoError(t, err) + require.NoError(t, svc.StartAsync(context.Background())) + require.NoError(t, svc.AwaitHealthy(context.Background())) + defer func() { + svc.StopAsync() + require.NoError(t, svc.AwaitStopped(context.Background())) + }() + + t.Run("using protocol grpc", func(t *testing.T) { + client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithGRPC()) + + resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{})) + require.NoError(t, err) + + require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names) + + assert.Equal(t, []string{ + "Content-Type: application/grpc", + "Grpc-Accept-Encoding: gzip", + "Grpc-Encoding: gzip", + }, headerToSlice(t, resp.Header())) + }) + + t.Run("using protocol grpc-web", func(t *testing.T) { + client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithGRPCWeb()) + + resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{})) + require.NoError(t, err) + + require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names) + + assert.Equal(t, []string{ + "Content-Type: application/grpc-web+proto", + "Grpc-Accept-Encoding: gzip", + "Grpc-Encoding: gzip", + }, headerToSlice(t, resp.Header())) + }) + + t.Run("using protocol json", func(t *testing.T) { + client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithProtoJSON()) + + resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{})) + require.NoError(t, err) + + require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names) + + assert.Equal(t, []string{ + "Accept-Encoding: gzip", + "Content-Encoding: gzip", + "Content-Type: application/json", 
+ }, headerToSlice(t, resp.Header())) + }) + +} + func TestFrontendRetryEnqueue(t *testing.T) { // Frontend uses worker concurrency to compute number of retries. We use one less failure. failures := atomic.NewInt64(testFrontendWorkerConcurrency - 1) diff --git a/pkg/util/connectgrpc/connectgrpc.go b/pkg/util/connectgrpc/connectgrpc.go index b0d333c9aa..1634d21d77 100644 --- a/pkg/util/connectgrpc/connectgrpc.go +++ b/pkg/util/connectgrpc/connectgrpc.go @@ -166,12 +166,30 @@ func removeContentHeaders(h http.Header) http.Header { return h } +// filterHeader returns true for headers that would expose implementation details of the underlying connectgrpc transport +func filterHeader(name string) bool { + if strings.ToLower(name) == "content-type" { + return true + } + if strings.ToLower(name) == "accept-encoding" { + return true + } + if strings.ToLower(name) == "content-encoding" { + return true + } + return false +} + func decodeResponse[Resp any](r *httpgrpc.HTTPResponse) (*connect.Response[Resp], error) { if err := decompressResponse(r); err != nil { return nil, err } resp := &connect.Response[Resp]{Msg: new(Resp)} for _, h := range r.Headers { + if filterHeader(h.Key) { + continue + } + for _, v := range h.Values { resp.Header().Add(h.Key, v) } diff --git a/pkg/util/connectgrpc/connectgrpc_test.go b/pkg/util/connectgrpc/connectgrpc_test.go index b8b71a0980..3c810f424d 100644 --- a/pkg/util/connectgrpc/connectgrpc_test.go +++ b/pkg/util/connectgrpc/connectgrpc_test.go @@ -1,6 +1,7 @@ package connectgrpc import ( + "bytes" "context" "net/http" "net/http/httptest" @@ -40,6 +41,17 @@ func (m *mockRoundTripper) RoundTripGRPC(_ context.Context, req *httpgrpc.HTTPRe return m.resp, nil } +func headerToSlice(t testing.TB, header http.Header) []string { + buf := new(bytes.Buffer) + excludeHeaders := map[string]bool{"Content-Length": true, "Date": true} + require.NoError(t, header.WriteSubset(buf, excludeHeaders)) + sl := strings.Split(strings.ReplaceAll(buf.String(), "\r\n", "\n"), "\n") + if len(sl) > 0 && sl[len(sl)-1] == "" { + sl = sl[:len(sl)-1] + } + return sl +} + func Test_RoundTripUnary(t *testing.T) { request := func(t *testing.T) *connect.Request[typesv1.LabelValuesRequest] { server := httptest.NewUnstartedServer(nil) @@ -64,8 +76,15 @@ func Test_RoundTripUnary(t *testing.T) { t.Run("HTTP request can trip GRPC", func(t *testing.T) { req := request(t) - m := &mockRoundTripper{resp: &httpgrpc.HTTPResponse{Code: 200}} - _, err := RoundTripUnary[typesv1.LabelValuesRequest, typesv1.LabelValuesResponse](context.Background(), m, req) + m := &mockRoundTripper{resp: &httpgrpc.HTTPResponse{ + Code: 200, + Headers: []*httpgrpc.Header{ + {Key: "Content-Type", Values: []string{"application/proto"}}, + {Key: "X-My-App", Values: []string{"foobar"}}, + }, + }} + + resp, err := RoundTripUnary[typesv1.LabelValuesRequest, typesv1.LabelValuesResponse](context.Background(), m, req) require.NoError(t, err) require.Equal(t, "POST", m.req.Method) require.Equal(t, "/querier.v1.QuerierService/LabelValues", m.req.Url) @@ -79,6 +98,10 @@ func Test_RoundTripUnary(t *testing.T) { decoded, err := decodeRequest[typesv1.LabelValuesRequest](m.req) require.NoError(t, err) require.Equal(t, req.Msg.Name, decoded.Msg.Name) + + // ensure no headers leak + require.Equal(t, []string{"X-My-App: foobar"}, headerToSlice(t, resp.Header())) + }) t.Run("HTTP request URL can be overridden", func(t *testing.T) { diff --git a/pkg/validation/testutil.go b/pkg/validation/testutil.go index 13e04f9a06..def8bb7743 100644 ---
a/pkg/validation/testutil.go +++ b/pkg/validation/testutil.go @@ -29,6 +29,8 @@ type MockLimits struct { MaxProfileStacktraceDepthValue int MaxProfileStacktraceSampleLabelsValue int MaxProfileSymbolValueLengthValue int + + MaxQueriersPerTenantValue int } func (m MockLimits) QuerySplitDuration(string) time.Duration { return m.QuerySplitDurationValue } @@ -70,6 +72,10 @@ func (m MockLimits) MaxProfileSymbolValueLength(userID string) int { return m.MaxProfileSymbolValueLengthValue } +func (m MockLimits) MaxQueriersPerTenant(_ string) int { + return m.MaxQueriersPerTenantValue +} + func (m MockLimits) RejectOlderThan(userID string) time.Duration { return m.RejectOlderThanValue }
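
The sketch below is a minimal, self-contained illustration of the failure mode described in the commit message and of the filtering idea the patch applies in decodeResponse. The filtered header names mirror the patch; everything else (the example header values, the main function) is illustrative only and not part of the Pyroscope code base.

```
// Minimal sketch (illustrative only): copying every header from the internal
// connectgrpc hop into the client-facing connect response duplicates
// Content-Type, because the frontend's connect handler sets its own. A proxy
// such as Envoy may then fold repeated headers into one comma-joined value
// ("application/grpc,application/proto"), which connect-go rejects. Dropping
// the transport headers before forwarding avoids the duplicate.
package main

import (
	"fmt"
	"net/http"
	"strings"
)

// filterHeader mirrors the patch: report headers that describe the internal
// transport rather than the application payload.
func filterHeader(name string) bool {
	switch strings.ToLower(name) {
	case "content-type", "content-encoding", "accept-encoding":
		return true
	}
	return false
}

func main() {
	// Headers as returned by the querier for the scheduled request (example values).
	upstream := http.Header{
		"Content-Type": {"application/proto"},
		"X-My-App":     {"foobar"},
	}

	// Client-facing response; the connect handler already set its own Content-Type.
	out := http.Header{"Content-Type": {"application/grpc"}}

	for name, values := range upstream {
		if filterHeader(name) {
			continue // do not leak transport headers from the internal hop
		}
		for _, v := range values {
			out.Add(name, v)
		}
	}

	fmt.Println(out.Values("Content-Type")) // [application/grpc] - no duplicate left to be merged
	fmt.Println(out.Get("X-My-App"))        // foobar - application headers still pass through
}
```

This matches what the new connectgrpc test asserts: after the round trip only application headers such as X-My-App remain on the decoded response, while Content-Type and the encoding headers are owned by the frontend's own connect handler.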