Skip to content

Commit

Permalink
/active_series: generate correct request shards for incoming GET
Browse files Browse the repository at this point in the history
…requests, handle gRPC errors (#7133)

* codes.Canceled is not an error either

* add test for GET-request based shards

* generate shard requests correctly

* read response body only if it's not nil

* add test for context propagation

* add org id to generated requests

* use wrapped grpc error

* add integration test

* add license header and build tag

* remove stray parens
  • Loading branch information
flxbk authored Jan 17, 2024
1 parent ca1e7bc commit 52c39fb
Show file tree
Hide file tree
Showing 8 changed files with 359 additions and 87 deletions.
82 changes: 82 additions & 0 deletions integration/e2emimir/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/gogo/protobuf/proto"
"github.com/golang/snappy"
"github.com/klauspost/compress/s2"
alertConfig "github.com/prometheus/alertmanager/config"
"github.com/prometheus/alertmanager/types"
promapi "github.com/prometheus/client_golang/api"
Expand Down Expand Up @@ -374,6 +375,87 @@ func (c *Client) LabelValuesCardinality(labelNames []string, selector string, li
return &lvalsResp, nil
}

type activeSeriesRequestConfig struct {
method string
useCompression bool
header http.Header
}

type ActiveSeriesOption func(*activeSeriesRequestConfig)

func WithEnableCompression() ActiveSeriesOption {
return func(c *activeSeriesRequestConfig) {
c.useCompression = true
c.header.Set("Accept-Encoding", "x-snappy-framed")
}
}

func WithRequestMethod(m string) ActiveSeriesOption {
return func(c *activeSeriesRequestConfig) {
c.method = m
}
}

func WithQueryShards(n int) ActiveSeriesOption {
return func(c *activeSeriesRequestConfig) {
c.header.Set("Sharding-Control", strconv.Itoa(n))
}
}

func (c *Client) ActiveSeries(selector string, options ...ActiveSeriesOption) (*api.ActiveSeriesResponse, error) {
cfg := activeSeriesRequestConfig{method: http.MethodGet, header: http.Header{"X-Scope-OrgID": []string{c.orgID}}}
for _, option := range options {
option(&cfg)
}

req, err := http.NewRequest(cfg.method, fmt.Sprintf("http://%s/prometheus/api/v1/cardinality/active_series", c.querierAddress), nil)
if err != nil {
return nil, err
}
req.Header = cfg.header

q := req.URL.Query()
q.Set("selector", selector)
switch cfg.method {
case http.MethodGet:
req.URL.RawQuery = q.Encode()
case http.MethodPost:
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Body = io.NopCloser(strings.NewReader(q.Encode()))
default:
return nil, fmt.Errorf("invalid method %s", cfg.method)
}

ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()

resp, err := c.httpClient.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
defer func(body io.ReadCloser) {
_, _ = io.ReadAll(body)
_ = body.Close()
}(resp.Body)

var bodyReader io.Reader = resp.Body
if resp.Header.Get("Content-Encoding") == "x-snappy-framed" {
bodyReader = s2.NewReader(bodyReader)
}

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(bodyReader)
return nil, fmt.Errorf("unexpected status code %d, body: %s", resp.StatusCode, body)
}

res := &api.ActiveSeriesResponse{}
err = json.NewDecoder(bodyReader).Decode(res)
if err != nil {
return nil, fmt.Errorf("error decoding active series response: %w", err)
}
return res, nil
}

// GetPrometheusMetadata fetches the metadata from the Prometheus endpoint /api/v1/metadata.
func (c *Client) GetPrometheusMetadata() (*http.Response, error) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s/prometheus/api/v1/metadata", c.querierAddress), nil)
Expand Down
139 changes: 139 additions & 0 deletions integration/query_frontend_active_series_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// SPDX-License-Identifier: AGPL-3.0-only
//go:build requires_docker

package integration

import (
"net/http"
"strconv"
"testing"
"time"

"github.com/grafana/e2e"
e2ecache "github.com/grafana/e2e/cache"
e2edb "github.com/grafana/e2e/db"
"github.com/prometheus/prometheus/model/labels"
"github.com/prometheus/prometheus/prompb"
"github.com/stretchr/testify/require"

"github.com/grafana/mimir/integration/e2emimir"
)

func TestActiveSeriesWithQueryShardingHTTP(t *testing.T) {
config := queryFrontendTestConfig{
queryStatsEnabled: true,
querySchedulerEnabled: true,
querySchedulerDiscoveryMode: "ring",
setup: func(t *testing.T, s *e2e.Scenario) (string, map[string]string) {
flags := mergeFlags(BlocksStorageFlags(), BlocksStorageS3Flags())
minio := e2edb.NewMinio(9000, flags["-blocks-storage.s3.bucket-name"])
require.NoError(t, s.StartAndWaitReady(minio))

return "", flags
},
}

runTestActiveSeriesWithQueryShardingHTTPTest(t, config)
}

func runTestActiveSeriesWithQueryShardingHTTPTest(t *testing.T, cfg queryFrontendTestConfig) {
s, err := e2e.NewScenario(networkName)
require.NoError(t, err)
defer s.Close()

memcached := e2ecache.NewMemcached()
consul := e2edb.NewConsul()
require.NoError(t, s.StartAndWaitReady(consul, memcached))

configFile, flags := cfg.setup(t, s)

flags = mergeFlags(flags, map[string]string{
"-query-frontend.cache-results": "true",
"-query-frontend.results-cache.backend": "memcached",
"-query-frontend.results-cache.memcached.addresses": "dns+" + memcached.NetworkEndpoint(e2ecache.MemcachedPort),
"-query-frontend.query-stats-enabled": strconv.FormatBool(cfg.queryStatsEnabled),
"-query-frontend.query-sharding-total-shards": "32",
"-query-frontend.query-sharding-max-sharded-queries": "128",
"-query-frontend.shard-active-series-queries": "true",
"-querier.cardinality-analysis-enabled": "true",
})

// Start the query-scheduler if enabled.
var queryScheduler *e2emimir.MimirService
if cfg.querySchedulerEnabled && cfg.querySchedulerDiscoveryMode == "dns" {
queryScheduler = e2emimir.NewQueryScheduler("query-scheduler", flags)
require.NoError(t, s.StartAndWaitReady(queryScheduler))
flags["-query-frontend.scheduler-address"] = queryScheduler.NetworkGRPCEndpoint()
flags["-querier.scheduler-address"] = queryScheduler.NetworkGRPCEndpoint()
} else if cfg.querySchedulerEnabled && cfg.querySchedulerDiscoveryMode == "ring" {
flags["-query-scheduler.service-discovery-mode"] = "ring"
flags["-query-scheduler.ring.store"] = "consul"
flags["-query-scheduler.ring.consul.hostname"] = consul.NetworkHTTPEndpoint()

queryScheduler = e2emimir.NewQueryScheduler("query-scheduler", flags)
require.NoError(t, s.StartAndWaitReady(queryScheduler))
}

// Start the query-frontend.
queryFrontend := e2emimir.NewQueryFrontend("query-frontend", flags, e2emimir.WithConfigFile(configFile))
require.NoError(t, s.Start(queryFrontend))

if !cfg.querySchedulerEnabled {
flags["-querier.frontend-address"] = queryFrontend.NetworkGRPCEndpoint()
}

// Start all other services.
ingester := e2emimir.NewIngester("ingester", consul.NetworkHTTPEndpoint(), flags, e2emimir.WithConfigFile(configFile))
distributor := e2emimir.NewDistributor("distributor", consul.NetworkHTTPEndpoint(), flags, e2emimir.WithConfigFile(configFile))
querier := e2emimir.NewQuerier("querier", consul.NetworkHTTPEndpoint(), flags, e2emimir.WithConfigFile(configFile))

require.NoError(t, s.StartAndWaitReady(querier, ingester, distributor))
require.NoError(t, s.WaitReady(queryFrontend))

// Check if we're discovering memcached or not.
require.NoError(t, queryFrontend.WaitSumMetrics(e2e.Equals(1), "thanos_cache_dns_provider_results"))
require.NoError(t, queryFrontend.WaitSumMetrics(e2e.Greater(0), "thanos_cache_dns_lookups_total"))

// Wait until distributor and querier have updated the ingesters ring.
require.NoError(t, distributor.WaitSumMetricsWithOptions(e2e.Equals(1), []string{"cortex_ring_members"}, e2e.WithLabelMatchers(
labels.MustNewMatcher(labels.MatchEqual, "name", "ingester"),
labels.MustNewMatcher(labels.MatchEqual, "state", "ACTIVE"))))

require.NoError(t, querier.WaitSumMetricsWithOptions(e2e.Equals(1), []string{"cortex_ring_members"}, e2e.WithLabelMatchers(
labels.MustNewMatcher(labels.MatchEqual, "name", "ingester"),
labels.MustNewMatcher(labels.MatchEqual, "state", "ACTIVE"))))

// Push series for the test user to Mimir.
c, err := e2emimir.NewClient(distributor.HTTPEndpoint(), queryFrontend.HTTPEndpoint(), "", "", userID)
require.NoError(t, err)
numSeries := 100
metricName := "test_metric"
now := time.Now()
series := make([]prompb.TimeSeries, numSeries)
for i := 0; i < numSeries; i++ {
ts, _, _ := e2e.GenerateSeries(metricName, now, prompb.Label{Name: "index", Value: strconv.Itoa(i)})
series[i] = ts[0]
}
res, err := c.Push(series)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)

// Query active series.
for _, options := range [][]e2emimir.ActiveSeriesOption{
{e2emimir.WithRequestMethod(http.MethodGet)},
{e2emimir.WithRequestMethod(http.MethodPost)},
{e2emimir.WithRequestMethod(http.MethodGet), e2emimir.WithEnableCompression()},
{e2emimir.WithRequestMethod(http.MethodPost), e2emimir.WithEnableCompression()},
{e2emimir.WithQueryShards(1)},
{e2emimir.WithQueryShards(12)},
} {
response, err := c.ActiveSeries(metricName, options...)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode)
require.Len(t, response.Data, numSeries)
}

_, err = c.ActiveSeries(metricName, e2emimir.WithQueryShards(512))
require.Error(t, err)
require.Contains(t, err.Error(), "shard count 512 exceeds allowed maximum (128)")
}
6 changes: 3 additions & 3 deletions pkg/distributor/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,7 @@ func (d *Distributor) ActiveSeries(ctx context.Context, matchers []*labels.Match

stream, err := client.ActiveSeries(ctx, req)
if err != nil {
if errors.Is(err, context.Canceled) {
if errors.Is(util.WrapGrpcContextError(err), context.Canceled) {
return ignored{}, nil
}
level.Error(log).Log("msg", "error creating active series response stream", "err", err)
Expand All @@ -1858,7 +1858,7 @@ func (d *Distributor) ActiveSeries(ctx context.Context, matchers []*labels.Match

defer func() {
err = util.CloseAndExhaust[*ingester_client.ActiveSeriesResponse](stream)
if err != nil {
if err != nil && !errors.Is(util.WrapGrpcContextError(err), context.Canceled) {
level.Warn(d.log).Log("msg", "error closing active series response stream", "err", err)
}
}()
Expand All @@ -1869,7 +1869,7 @@ func (d *Distributor) ActiveSeries(ctx context.Context, matchers []*labels.Match
if errors.Is(err, io.EOF) {
break
}
if errors.Is(err, context.Canceled) {
if errors.Is(util.WrapGrpcContextError(err), context.Canceled) {
return ignored{}, nil
}
level.Error(log).Log("msg", "error receiving active series response", "err", err)
Expand Down
25 changes: 18 additions & 7 deletions pkg/frontend/querymiddleware/shard_active_series.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"net/http"
"net/url"
"strconv"
"strings"

"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/grafana/dskit/tenant"
"github.com/grafana/dskit/user"
jsoniter "github.com/json-iterator/go"
"github.com/klauspost/compress/s2"
"github.com/opentracing/opentracing-go"
Expand Down Expand Up @@ -138,7 +138,10 @@ func parseSelector(req *http.Request) (*parser.VectorSelector, error) {
func buildShardedRequests(ctx context.Context, req *http.Request, numRequests int, selector parser.Expr) ([]*http.Request, error) {
reqs := make([]*http.Request, numRequests)
for i := 0; i < numRequests; i++ {
reqs[i] = req.Clone(ctx)
r, err := http.NewRequestWithContext(ctx, http.MethodGet, req.URL.Path, http.NoBody)
if err != nil {
return nil, err
}

sharded, err := shardedSelector(numRequests, i, selector)
if err != nil {
Expand All @@ -147,11 +150,16 @@ func buildShardedRequests(ctx context.Context, req *http.Request, numRequests in

vals := url.Values{}
vals.Set("selector", sharded.String())
r.URL.RawQuery = vals.Encode()
// This is the field read by httpgrpc.FromHTTPRequest, so we need to populate it
// here to ensure the request parameter makes it to the querier.
r.RequestURI = r.URL.String()

reqs[i].Header.Set("Content-Type", "application/x-www-form-urlencoded")
reqs[i].Header.Del(totalShardsControlHeader)
reqs[i].Header.Del("Accept-Encoding")
reqs[i].Body = io.NopCloser(strings.NewReader(vals.Encode()))
if err := user.InjectOrgIDIntoHTTPRequest(ctx, r); err != nil {
return nil, err
}

reqs[i] = r
}

return reqs, nil
Expand Down Expand Up @@ -180,10 +188,13 @@ func doShardedRequests(ctx context.Context, upstreamRequests []*http.Request, ne

if resp.StatusCode != http.StatusOK {
span.LogFields(otlog.Int("statusCode", resp.StatusCode))
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == http.StatusRequestEntityTooLarge {
return errShardCountTooLow
}
var body []byte
if resp.Body != nil {
body, _ = io.ReadAll(resp.Body)
}
return fmt.Errorf("received unexpected response from upstream: status %d, body: %s", resp.StatusCode, string(body))
}

Expand Down
34 changes: 32 additions & 2 deletions pkg/frontend/querymiddleware/shard_active_series_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) {

validReq := func() *http.Request {
r := httptest.NewRequest("POST", "/active_series", strings.NewReader(`selector={__name__="metric"}`))
r.Header.Add("X-Scope-OrgID", "test")
r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
return r
}
Expand Down Expand Up @@ -239,6 +240,28 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) {
},
expectContentEncoding: encodingTypeSnappyFramed,
},
{
name: "builds correct request shards for GET requests",
request: func() *http.Request {
q := url.Values{}
q.Set("selector", "metric")
req, _ := http.NewRequest(http.MethodGet, "/active_series", nil)
req.URL.RawQuery = q.Encode()
req.Header.Add(totalShardsControlHeader, "2")
return req
},
validResponses: [][]labels.Labels{
{labels.FromStrings(labels.MetricName, "metric", "shard", "1")},
{labels.FromStrings(labels.MetricName, "metric", "shard", "2")},
},
checkResponseErr: noError,
expect: result{
Data: []labels.Labels{
labels.FromStrings(labels.MetricName, "metric", "shard", "1"),
labels.FromStrings(labels.MetricName, "metric", "shard", "2"),
},
},
},
}

for _, tt := range tests {
Expand All @@ -247,10 +270,17 @@ func Test_shardActiveSeriesMiddleware_RoundTrip(t *testing.T) {
// Stub upstream with valid or invalid responses.
var requestCount atomic.Int32
upstream := RoundTripFunc(func(r *http.Request) (*http.Response, error) {
defer func(Body io.ReadCloser) {
_ = Body.Close()
defer func(body io.ReadCloser) {
if body != nil {
_ = body.Close()
}
}(r.Body)

_, _, err := user.ExtractOrgIDFromHTTPRequest(r)
require.NoError(t, err)
_, err = user.ExtractOrgID(r.Context())
require.NoError(t, err)

requestCount.Inc()

if tt.errorResponse != nil {
Expand Down
Loading

0 comments on commit 52c39fb

Please sign in to comment.