diff --git a/pkg/frontend/frontend_test.go b/pkg/frontend/frontend_test.go
index a9f1b3e75e..7d26c30777 100644
--- a/pkg/frontend/frontend_test.go
+++ b/pkg/frontend/frontend_test.go
@@ -6,6 +6,7 @@ package frontend

 import (
+	"bytes"
 	"context"
 	"fmt"
 	"net/http"
@@ -26,6 +27,7 @@ import (
 	"github.com/grafana/dskit/services"
 	"github.com/grafana/dskit/test"
 	"github.com/grafana/dskit/user"
+	"github.com/opentracing/opentracing-go"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/testutil"
 	"github.com/stretchr/testify/assert"
@@ -34,12 +36,18 @@ import (
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/h2c"

+	"github.com/grafana/pyroscope/api/gen/proto/go/querier/v1/querierv1connect"
+	typesv1 "github.com/grafana/pyroscope/api/gen/proto/go/types/v1"
+	connectapi "github.com/grafana/pyroscope/pkg/api/connect"
 	"github.com/grafana/pyroscope/pkg/frontend/frontendpb"
 	"github.com/grafana/pyroscope/pkg/frontend/frontendpb/frontendpbconnect"
 	"github.com/grafana/pyroscope/pkg/querier/stats"
+	"github.com/grafana/pyroscope/pkg/querier/worker"
+	"github.com/grafana/pyroscope/pkg/scheduler"
 	"github.com/grafana/pyroscope/pkg/scheduler/schedulerdiscovery"
 	"github.com/grafana/pyroscope/pkg/scheduler/schedulerpb"
 	"github.com/grafana/pyroscope/pkg/scheduler/schedulerpb/schedulerpbconnect"
+	"github.com/grafana/pyroscope/pkg/util/connectgrpc"
 	"github.com/grafana/pyroscope/pkg/util/httpgrpc"
 	"github.com/grafana/pyroscope/pkg/util/servicediscovery"
 	"github.com/grafana/pyroscope/pkg/validation"
@@ -51,14 +59,8 @@ func setupFrontend(t *testing.T, reg prometheus.Registerer, schedulerReplyFunc f
 	return setupFrontendWithConcurrencyAndServerOptions(t, reg, schedulerReplyFunc, testFrontendWorkerConcurrency)
 }

-func setupFrontendWithConcurrencyAndServerOptions(t *testing.T, reg prometheus.Registerer, schedulerReplyFunc func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend, concurrency int) (*Frontend, *mockScheduler) {
-	s := httptest.NewUnstartedServer(nil)
-	mux := mux.NewRouter()
-	s.Config.Handler = h2c.NewHandler(mux, &http2.Server{})
-
-	s.Start()
-
-	u, err := url.Parse(s.URL)
+func cfgFromURL(t *testing.T, urlS string) Config {
+	u, err := url.Parse(urlS)
 	require.NoError(t, err)

 	port, err := strconv.Atoi(u.Port())
@@ -67,9 +69,20 @@ func setupFrontendWithConcurrencyAndServerOptions(t *testing.T, reg prometheus.R
 	cfg := Config{}
 	flagext.DefaultValues(&cfg)
 	cfg.SchedulerAddress = u.Hostname() + ":" + u.Port()
-	cfg.WorkerConcurrency = concurrency
 	cfg.Addr = u.Hostname()
 	cfg.Port = port
+	return cfg
+}
+
+func setupFrontendWithConcurrencyAndServerOptions(t *testing.T, reg prometheus.Registerer, schedulerReplyFunc func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend, concurrency int) (*Frontend, *mockScheduler) {
+	s := httptest.NewUnstartedServer(nil)
+	mux := mux.NewRouter()
+	s.Config.Handler = h2c.NewHandler(mux, &http2.Server{})
+
+	s.Start()
+
+	cfg := cfgFromURL(t, s.URL)
+	cfg.WorkerConcurrency = concurrency

 	logger := log.NewLogfmtLogger(os.Stdout)

 	f, err := NewFrontend(cfg, validation.MockLimits{MaxQueryParallelismValue: 1}, logger, reg)
@@ -181,6 +194,137 @@ func TestFrontendRequestsPerWorkerMetric(t *testing.T) {
 	require.NoError(t, testutil.GatherAndCompare(reg, strings.NewReader(expectedMetrics),
 		"pyroscope_query_frontend_workers_enqueued_requests_total"))
 }

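+// newFakeQuerierGRPCHandler exposes the fakeQuerier below as a
+// connectgrpc.GRPCHandler, so a querier worker can serve it in-process.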
"pyroscope_query_frontend_workers_enqueued_requests_total")) } +func newFakeQuerierGRPCHandler() connectgrpc.GRPCHandler { + q := &fakeQuerier{} + mux := http.NewServeMux() + mux.Handle(querierv1connect.NewQuerierServiceHandler(q, connectapi.DefaultHandlerOptions()...)) + return connectgrpc.NewHandler(mux) +} + +type fakeQuerier struct { + querierv1connect.QuerierServiceHandler +} + +func (f *fakeQuerier) LabelNames(ctx context.Context, req *connect.Request[typesv1.LabelNamesRequest]) (*connect.Response[typesv1.LabelNamesResponse], error) { + return connect.NewResponse(&typesv1.LabelNamesResponse{ + Names: []string{"i", "have", "labels"}, + }), nil +} + +func headerToSlice(t testing.TB, header http.Header) []string { + buf := new(bytes.Buffer) + excludeHeaders := map[string]bool{"Content-Length": true, "Date": true} + require.NoError(t, header.WriteSubset(buf, excludeHeaders)) + sl := strings.Split(strings.ReplaceAll(buf.String(), "\r\n", "\n"), "\n") + if len(sl) > 0 && sl[len(sl)-1] == "" { + sl = sl[:len(sl)-1] + } + return sl +} + +// TestFrontendFullRoundtrip tests the full roundtrip of a request from the frontend to a fake querier and back, with using an actual scheduler. +func TestFrontendFullRoundtrip(t *testing.T) { + var ( + logger = log.NewNopLogger() + reg = prometheus.NewRegistry() + tenant = "tenant-a" + ) + if testing.Verbose() { + logger = log.NewLogfmtLogger(os.Stderr) + } + + // create server for frontend and scheduler + mux := mux.NewRouter() + // inject a span/tenant into the context + mux.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := user.InjectOrgID(r.Context(), tenant) + _, ctx = opentracing.StartSpanFromContext(ctx, "test") + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }) + s := httptest.NewServer(h2c.NewHandler(mux, &http2.Server{})) + defer s.Close() + + // initialize the scheduler + schedCfg := scheduler.Config{} + flagext.DefaultValues(&schedCfg) + sched, err := scheduler.NewScheduler(schedCfg, validation.MockLimits{}, logger, reg) + require.NoError(t, err) + schedulerpbconnect.RegisterSchedulerForFrontendHandler(mux, sched) + schedulerpbconnect.RegisterSchedulerForQuerierHandler(mux, sched) + + // initialize the frontend + fCfg := cfgFromURL(t, s.URL) + f, err := NewFrontend(fCfg, validation.MockLimits{MaxQueryParallelismValue: 1}, logger, reg) + require.NoError(t, err) + frontendpbconnect.RegisterFrontendForQuerierHandler(mux, f) // probably not needed + querierv1connect.RegisterQuerierServiceHandler(mux, f) + + // create a querier worker + qWorkerCfg := worker.Config{} + flagext.DefaultValues(&qWorkerCfg) + qWorkerCfg.SchedulerAddress = fCfg.SchedulerAddress + qWorker, err := worker.NewQuerierWorker(qWorkerCfg, newFakeQuerierGRPCHandler(), log.NewLogfmtLogger(os.Stderr), prometheus.NewRegistry()) + require.NoError(t, err) + + // start services + svc, err := services.NewManager(sched, f, qWorker) + require.NoError(t, err) + require.NoError(t, svc.StartAsync(context.Background())) + require.NoError(t, svc.AwaitHealthy(context.Background())) + defer func() { + svc.StopAsync() + require.NoError(t, svc.AwaitStopped(context.Background())) + }() + + t.Run("using protocol grpc", func(t *testing.T) { + client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithGRPC()) + + resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{})) + require.NoError(t, err) + + require.Equal(t, []string{"i", "have", "labels"}, 
+	t.Run("using protocol grpc", func(t *testing.T) {
+		client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithGRPC())
+
+		resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{}))
+		require.NoError(t, err)
+
+		require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names)
+
+		assert.Equal(t, []string{
+			"Content-Type: application/grpc",
+			"Grpc-Accept-Encoding: gzip",
+			"Grpc-Encoding: gzip",
+		}, headerToSlice(t, resp.Header()))
+	})
+
+	t.Run("using protocol grpc-web", func(t *testing.T) {
+		client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithGRPCWeb())
+
+		resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{}))
+		require.NoError(t, err)
+
+		require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names)
+
+		assert.Equal(t, []string{
+			"Content-Type: application/grpc-web+proto",
+			"Grpc-Accept-Encoding: gzip",
+			"Grpc-Encoding: gzip",
+		}, headerToSlice(t, resp.Header()))
+	})
+
+	t.Run("using protocol json", func(t *testing.T) {
+		client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithProtoJSON())
+
+		resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{}))
+		require.NoError(t, err)
+
+		require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names)
+
+		assert.Equal(t, []string{
+			"Accept-Encoding: gzip",
+			"Content-Encoding: gzip",
+			"Content-Type: application/json",
+		}, headerToSlice(t, resp.Header()))
+	})
+
+}
+
 func TestFrontendRetryEnqueue(t *testing.T) {
 	// Frontend uses worker concurrency to compute number of retries. We use one less failure.
 	failures := atomic.NewInt64(testFrontendWorkerConcurrency - 1)
diff --git a/pkg/util/connectgrpc/connectgrpc.go b/pkg/util/connectgrpc/connectgrpc.go
index b0d333c9aa..1634d21d77 100644
--- a/pkg/util/connectgrpc/connectgrpc.go
+++ b/pkg/util/connectgrpc/connectgrpc.go
@@ -166,12 +166,30 @@ func removeContentHeaders(h http.Header) http.Header {
 	return h
 }

+// filterHeader reports whether a header would expose implementation details of the connectgrpc transport and should therefore be dropped.
+func filterHeader(name string) bool {
+	if strings.ToLower(name) == "content-type" {
+		return true
+	}
+	if strings.ToLower(name) == "accept-encoding" {
+		return true
+	}
+	if strings.ToLower(name) == "content-encoding" {
+		return true
+	}
+	return false
+}
+
 func decodeResponse[Resp any](r *httpgrpc.HTTPResponse) (*connect.Response[Resp], error) {
 	if err := decompressResponse(r); err != nil {
 		return nil, err
 	}
 	resp := &connect.Response[Resp]{Msg: new(Resp)}
 	for _, h := range r.Headers {
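+		// skip transport-level headers so callers only see headers set by
+		// the downstream handler itself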
trip GRPC", func(t *testing.T) { req := request(t) - m := &mockRoundTripper{resp: &httpgrpc.HTTPResponse{Code: 200}} - _, err := RoundTripUnary[typesv1.LabelValuesRequest, typesv1.LabelValuesResponse](context.Background(), m, req) + m := &mockRoundTripper{resp: &httpgrpc.HTTPResponse{ + Code: 200, + Headers: []*httpgrpc.Header{ + {Key: "Content-Type", Values: []string{"application/proto"}}, + {Key: "X-My-App", Values: []string{"foobar"}}, + }, + }} + + resp, err := RoundTripUnary[typesv1.LabelValuesRequest, typesv1.LabelValuesResponse](context.Background(), m, req) require.NoError(t, err) require.Equal(t, "POST", m.req.Method) require.Equal(t, "/querier.v1.QuerierService/LabelValues", m.req.Url) @@ -79,6 +98,10 @@ func Test_RoundTripUnary(t *testing.T) { decoded, err := decodeRequest[typesv1.LabelValuesRequest](m.req) require.NoError(t, err) require.Equal(t, req.Msg.Name, decoded.Msg.Name) + + // ensure no headers leak + require.Equal(t, []string{"X-My-App: foobar"}, headerToSlice(t, resp.Header())) + }) t.Run("HTTP request URL can be overridden", func(t *testing.T) { diff --git a/pkg/validation/testutil.go b/pkg/validation/testutil.go index 13e04f9a06..def8bb7743 100644 --- a/pkg/validation/testutil.go +++ b/pkg/validation/testutil.go @@ -29,6 +29,8 @@ type MockLimits struct { MaxProfileStacktraceDepthValue int MaxProfileStacktraceSampleLabelsValue int MaxProfileSymbolValueLengthValue int + + MaxQueriersPerTenantValue int } func (m MockLimits) QuerySplitDuration(string) time.Duration { return m.QuerySplitDurationValue } @@ -70,6 +72,10 @@ func (m MockLimits) MaxProfileSymbolValueLength(userID string) int { return m.MaxProfileSymbolValueLengthValue } +func (m MockLimits) MaxQueriersPerTenant(_ string) int { + return m.MaxQueriersPerTenantValue +} + func (m MockLimits) RejectOlderThan(userID string) time.Duration { return m.RejectOlderThanValue }