diff --git a/docs/sources/shared/configuration.md b/docs/sources/shared/configuration.md index b287bdea5f37..5e119c6e5de1 100644 --- a/docs/sources/shared/configuration.md +++ b/docs/sources/shared/configuration.md @@ -2879,6 +2879,11 @@ wal: # common.path_prefix is set then common.path_prefix will be used. # CLI flag: -ingester.shutdown-marker-path [shutdown_marker_path: | default = ""] + +# Interval at which the ingester ownedStreamService checks for changes in the +# ring to recalculate owned streams. +# CLI flag: -ingester.owned-streams-check-interval +[owned_streams_check_interval: | default = 30s] ``` ### ingester_client @@ -3540,7 +3545,7 @@ When a memberlist config with at least 1 join_members is defined, kvstore of type # The timeout for establishing a connection with a remote node, and for # read/write operations. # CLI flag: -memberlist.stream-timeout -[stream_timeout: | default = 10s] +[stream_timeout: | default = 2s] # Multiplication factor used when sending out messages (factor * log(N+1)). # CLI flag: -memberlist.retransmit-factor @@ -4755,6 +4760,10 @@ Configures the `server` of the launched module(s). # CLI flag: -server.grpc-conn-limit [grpc_listen_conn_limit: | default = 0] +# Enables PROXY protocol. +# CLI flag: -server.proxy-protocol-enabled +[proxy_protocol_enabled: | default = false] + # Comma-separated list of cipher suites to use. If blank, the default Go cipher # suites is used. # CLI flag: -server.tls-cipher-suites @@ -4909,6 +4918,21 @@ grpc_tls_config: # CLI flag: -server.grpc.num-workers [grpc_server_num_workers: | default = 0] +# If true, the request_message_bytes, response_message_bytes, and +# inflight_requests metrics will be tracked. Enabling this option prevents the +# use of memory pools for parsing gRPC request bodies and may lead to more +# memory allocations. +# CLI flag: -server.grpc.stats-tracking-enabled +[grpc_server_stats_tracking_enabled: | default = true] + +# If true, gRPC's buffer pools will be used to handle incoming requests. +# Enabling this feature can reduce memory allocation, but also requires +# disabling gRPC server stats tracking by setting +# `server.grpc.stats-tracking-enabled=false`. This is an experimental gRPC +# feature, so it might be removed in a future version of the gRPC library. +# CLI flag: -server.grpc.recv-buffer-pools-enabled +[grpc_server_recv_buffer_pools_enabled: | default = false] + # Output log messages in the given format. Valid formats: [logfmt, json] # CLI flag: -log.format [log_format: | default = "logfmt"] @@ -4922,6 +4946,11 @@ grpc_tls_config: # CLI flag: -server.log-source-ips-enabled [log_source_ips_enabled: | default = false] +# Log all source IPs instead of only the originating one. Only used if +# server.log-source-ips-enabled is true. +# CLI flag: -server.log-source-ips-full +[log_source_ips_full: | default = false] + # Header field storing the source IPs. Only used if # server.log-source-ips-enabled is true.
If not set the default Forwarded, # X-Real-IP and X-Forwarded-For headers are used diff --git a/go.mod b/go.mod index 051a18d2292f..b4ed8ae7dea0 100644 --- a/go.mod +++ b/go.mod @@ -50,9 +50,9 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/grafana/cloudflare-go v0.0.0-20230110200409-c627cf6792f2 - github.com/grafana/dskit v0.0.0-20240104111617-ea101a3b86eb + github.com/grafana/dskit v0.0.0-20240528015923-27d7d41066d3 github.com/grafana/go-gelf/v2 v2.0.1 - github.com/grafana/gomemcache v0.0.0-20231204155601-7de47a8c3cb0 + github.com/grafana/gomemcache v0.0.0-20240229205252-cd6a66d6fb56 github.com/grafana/regexp v0.0.0-20221122212121-6b5c0a4cb7fd github.com/grafana/tail v0.0.0-20230510142333-77b18831edf0 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 @@ -153,6 +153,7 @@ require ( github.com/dlclark/regexp2 v1.4.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect + github.com/pires/go-proxyproto v0.7.0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect @@ -312,7 +313,6 @@ require ( github.com/sercand/kuberesolver/v5 v5.1.1 // indirect github.com/shopspring/decimal v1.2.0 // indirect github.com/sirupsen/logrus v1.9.3 - github.com/soheilhy/cmux v0.1.5 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/go.sum b/go.sum index 37ea77d49215..0c969948e109 100644 --- a/go.sum +++ b/go.sum @@ -1017,14 +1017,14 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grafana/cloudflare-go v0.0.0-20230110200409-c627cf6792f2 h1:qhugDMdQ4Vp68H0tp/0iN17DM2ehRo1rLEdOFe/gB8I= github.com/grafana/cloudflare-go v0.0.0-20230110200409-c627cf6792f2/go.mod h1:w/aiO1POVIeXUQyl0VQSZjl5OAGDTL5aX+4v0RA1tcw= -github.com/grafana/dskit v0.0.0-20240104111617-ea101a3b86eb h1:AWE6+kvtE18HP+lRWNUCyvymyrFSXs6TcS2vXIXGIuw= -github.com/grafana/dskit v0.0.0-20240104111617-ea101a3b86eb/go.mod h1:kkWM4WUV230bNG3urVRWPBnSJHs64y/0RmWjftnnn0c= +github.com/grafana/dskit v0.0.0-20240528015923-27d7d41066d3 h1:k8vINlI4w+RYc37NRwQlRe/IHYoEbu6KAe2XdGDeV1U= +github.com/grafana/dskit v0.0.0-20240528015923-27d7d41066d3/go.mod h1:HvSf3uf8Ps2vPpzHeAFyZTdUcbVr+Rxpq1xcx7J/muc= github.com/grafana/go-gelf/v2 v2.0.1 h1:BOChP0h/jLeD+7F9mL7tq10xVkDG15he3T1zHuQaWak= github.com/grafana/go-gelf/v2 v2.0.1/go.mod h1:lexHie0xzYGwCgiRGcvZ723bSNyNI8ZRD4s0CLobh90= github.com/grafana/gocql v0.0.0-20200605141915-ba5dc39ece85 h1:xLuzPoOzdfNb/RF/IENCw+oLVdZB4G21VPhkHBgwSHY= github.com/grafana/gocql v0.0.0-20200605141915-ba5dc39ece85/go.mod h1:crI9WX6p0IhrqB+DqIUHulRW853PaNFf7o4UprV//3I= -github.com/grafana/gomemcache v0.0.0-20231204155601-7de47a8c3cb0 h1:aLBiDMjTtXx2800iCIp+8kdjIlvGX0MF/zICQMQO2qU= -github.com/grafana/gomemcache v0.0.0-20231204155601-7de47a8c3cb0/go.mod h1:PGk3RjYHpxMM8HFPhKKo+vve3DdlPUELZLSDEFehPuU= +github.com/grafana/gomemcache 
v0.0.0-20240229205252-cd6a66d6fb56 h1:X8IKQ0wu40wpvYcKfBcc5T4QnhdQjUhtUtB/1CY89lE= +github.com/grafana/gomemcache v0.0.0-20240229205252-cd6a66d6fb56/go.mod h1:PGk3RjYHpxMM8HFPhKKo+vve3DdlPUELZLSDEFehPuU= github.com/grafana/jsonparser v0.0.0-20240209175146-098958973a2d h1:YwbJJ/PrVWVdnR+j/EAVuazdeP+Za5qbiH1Vlr+wFXs= github.com/grafana/jsonparser v0.0.0-20240209175146-098958973a2d/go.mod h1:796sq+UcONnSlzA3RtlBZ+b/hrerkZXiEmO8oMjyRwY= github.com/grafana/memberlist v0.3.1-0.20220714140823-09ffed8adbbe h1:yIXAAbLswn7VNWBIvM71O2QsgfgW9fRXZNR0DXe6pDU= @@ -1541,6 +1541,8 @@ github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4/v4 v4.1.18 h1:xaKrnTkyoqfh1YItXl56+6KJNVYWlEEPuAQW9xsplYQ= github.com/pierrec/lz4/v4 v4.1.18/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= +github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -1697,8 +1699,6 @@ github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:s github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/softlayer/softlayer-go v0.0.0-20180806151055-260589d94c7d/go.mod h1:Cw4GTlQccdRGSEf6KiMju767x0NEHE0YIVPJSaXjlsw= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= -github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= -github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/soniah/gosnmp v1.25.0/go.mod h1:8YvfZxH388NIIw2A+X5z2Oh97VcNhtmxDLt5QeUzVuQ= github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/sony/gobreaker v0.5.0 h1:dRCvqm0P490vZPmy7ppEk2qCnCieBooFJ+YoXGYB+yg= diff --git a/pkg/ingester/checkpoint_test.go b/pkg/ingester/checkpoint_test.go index d530d937d42f..1b0c76466dc1 100644 --- a/pkg/ingester/checkpoint_test.go +++ b/pkg/ingester/checkpoint_test.go @@ -70,7 +70,9 @@ func TestIngesterWAL(t *testing.T) { } } - i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -113,7 +115,7 @@ func TestIngesterWAL(t *testing.T) { expectCheckpoint(t, walDir, false, time.Second) // restart the ingester - i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, 
gokit_log.NewNopLogger(), nil) + i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) @@ -127,7 +129,7 @@ func TestIngesterWAL(t *testing.T) { require.Nil(t, services.StopAndAwaitTerminated(context.Background(), i)) // restart the ingester - i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) @@ -150,7 +152,9 @@ func TestIngesterWALIgnoresStreamLimits(t *testing.T) { } } - i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -196,7 +200,7 @@ func TestIngesterWALIgnoresStreamLimits(t *testing.T) { require.NoError(t, err) // restart the ingester - i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) @@ -253,7 +257,9 @@ func TestIngesterWALBackpressureSegments(t *testing.T) { } } - i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -274,7 +280,7 @@ func TestIngesterWALBackpressureSegments(t *testing.T) { expectCheckpoint(t, walDir, false, time.Second) // restart the ingester, ensuring we replayed from WAL. 
- i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) @@ -295,7 +301,9 @@ func TestIngesterWALBackpressureCheckpoint(t *testing.T) { } } - i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -316,7 +324,7 @@ func TestIngesterWALBackpressureCheckpoint(t *testing.T) { require.Nil(t, services.StopAndAwaitTerminated(context.Background(), i)) // restart the ingester, ensuring we can replay from the checkpoint as well. - i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) @@ -591,7 +599,9 @@ func TestIngesterWALReplaysUnorderedToOrdered(t *testing.T) { } } - i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -663,7 +673,7 @@ func TestIngesterWALReplaysUnorderedToOrdered(t *testing.T) { require.NoError(t, err) // restart the ingester - i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil) + i, err = New(ingesterConfig, client.Config{}, newStore(), limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokit_log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck require.Nil(t, services.StartAndAwaitRunning(context.Background(), i)) diff --git a/pkg/ingester/flush_test.go b/pkg/ingester/flush_test.go index edd6084a2741..460a50ffc8fa 100644 --- a/pkg/ingester/flush_test.go +++ b/pkg/ingester/flush_test.go @@ -337,10 +337,12 @@ func newTestStore(t 
require.TestingT, cfg Config, walOverride WAL) (*testStore, chunks: map[string][]chunk.Chunk{}, } + readRingMock := mockReadRingWithOneActiveIngester() + limits, err := validation.NewOverrides(defaultLimitsTestConfig(), nil) require.NoError(t, err) - ing, err := New(cfg, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokitlog.NewNopLogger(), nil) + ing, err := New(cfg, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, gokitlog.NewNopLogger(), nil, readRingMock) require.NoError(t, err) require.NoError(t, services.StartAndAwaitRunning(context.Background(), ing)) @@ -376,6 +378,7 @@ func defaultIngesterTestConfig(t testing.TB) Config { cfg.BlockSize = 256 * 1024 cfg.TargetChunkSize = 1500 * 1024 cfg.WAL.Enabled = false + cfg.OwnedStreamsCheckInterval = 1 * time.Second return cfg } diff --git a/pkg/ingester/ingester.go b/pkg/ingester/ingester.go index 1a89aebe6ef9..464204334afc 100644 --- a/pkg/ingester/ingester.go +++ b/pkg/ingester/ingester.go @@ -122,6 +122,8 @@ type Config struct { MaxDroppedStreams int `yaml:"max_dropped_streams"` ShutdownMarkerPath string `yaml:"shutdown_marker_path"` + + OwnedStreamsCheckInterval time.Duration `yaml:"owned_streams_check_interval" doc:"description=Interval at which the ingester ownedStreamService checks for changes in the ring to recalculate owned streams."` } // RegisterFlags registers the flags. @@ -149,6 +151,7 @@ func (cfg *Config) RegisterFlags(f *flag.FlagSet) { f.IntVar(&cfg.IndexShards, "ingester.index-shards", index.DefaultIndexShards, "Shard factor used in the ingesters for the in process reverse index. This MUST be evenly divisible by ALL schema shard factors or Loki will not start.") f.IntVar(&cfg.MaxDroppedStreams, "ingester.tailer.max-dropped-streams", 10, "Maximum number of dropped streams to keep in memory during tailing.") f.StringVar(&cfg.ShutdownMarkerPath, "ingester.shutdown-marker-path", "", "Path where the shutdown marker file is stored. If not set and common.path_prefix is set then common.path_prefix will be used.") + f.DurationVar(&cfg.OwnedStreamsCheckInterval, "ingester.owned-streams-check-interval", 30*time.Second, "Interval at which the ingester ownedStreamService checks for changes in the ring to recalculate owned streams.") } func (cfg *Config) Validate() error { @@ -262,10 +265,14 @@ type Ingester struct { writeLogManager *writefailures.Manager customStreamsTracker push.UsageTracker + + // recalculateOwnedStreams periodically checks the ring for changes and recalculates owned streams for each instance. + readRing ring.ReadRing + recalculateOwnedStreams *recalculateOwnedStreams } // New makes a new Ingester. 
-func New(cfg Config, clientConfig client.Config, store Store, limits Limits, configs *runtime.TenantConfigs, registerer prometheus.Registerer, writeFailuresCfg writefailures.Cfg, metricsNamespace string, logger log.Logger, customStreamsTracker push.UsageTracker) (*Ingester, error) { +func New(cfg Config, clientConfig client.Config, store Store, limits Limits, configs *runtime.TenantConfigs, registerer prometheus.Registerer, writeFailuresCfg writefailures.Cfg, metricsNamespace string, logger log.Logger, customStreamsTracker push.UsageTracker, readRing ring.ReadRing) (*Ingester, error) { if cfg.ingesterClientFactory == nil { cfg.ingesterClientFactory = client.New } @@ -294,6 +301,7 @@ func New(cfg Config, clientConfig client.Config, store Store, limits Limits, con streamRateCalculator: NewStreamRateCalculator(), writeLogManager: writefailures.NewManager(logger, registerer, writeFailuresCfg, configs, "ingester"), customStreamsTracker: customStreamsTracker, + readRing: readRing, } i.replayController = newReplayController(metrics, cfg.WAL, &replayFlusher{i}) @@ -343,6 +351,8 @@ func New(cfg Config, clientConfig client.Config, store Store, limits Limits, con i.SetExtractorWrapper(i.cfg.SampleExtractorWrapper) } + i.recalculateOwnedStreams = newRecalculateOwnedStreams(i.getInstances, i.lifecycler.ID, i.readRing, cfg.OwnedStreamsCheckInterval, util_log.Logger) + return i, nil } @@ -536,6 +546,16 @@ func (i *Ingester) starting(ctx context.Context) error { i.setPrepareShutdown() } + err = i.recalculateOwnedStreams.StartAsync(ctx) + if err != nil { + return fmt.Errorf("can not start recalculate owned streams service: %w", err) + } + + err = i.lifecycler.AwaitRunning(ctx) + if err != nil { + return fmt.Errorf("can not ensure recalculate owned streams service is running: %w", err) + } + // start our loop i.loopDone.Add(1) go i.loop() diff --git a/pkg/ingester/ingester_test.go b/pkg/ingester/ingester_test.go index 6bb27ad645cc..444e1317e697 100644 --- a/pkg/ingester/ingester_test.go +++ b/pkg/ingester/ingester_test.go @@ -2,6 +2,7 @@ package ingester import ( "fmt" + math "math" "net" "net/http" "net/http/httptest" @@ -16,8 +17,10 @@ import ( "github.com/grafana/dskit/flagext" "github.com/grafana/dskit/httpgrpc" "github.com/grafana/dskit/middleware" + "github.com/grafana/dskit/ring" "github.com/grafana/dskit/services" "github.com/grafana/dskit/user" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/labels" "github.com/stretchr/testify/require" @@ -58,7 +61,9 @@ func TestPrepareShutdownMarkerPathNotSet(t *testing.T) { chunks: map[string][]chunk.Chunk{}, } - i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + mockRing := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, mockRing) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -81,7 +86,9 @@ func TestPrepareShutdown(t *testing.T) { chunks: map[string][]chunk.Chunk{}, } - i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + 
+ i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -142,7 +149,9 @@ func TestIngester_GetStreamRates_Correctness(t *testing.T) { chunks: map[string][]chunk.Chunk{}, } - i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -173,8 +182,9 @@ func BenchmarkGetStreamRatesAllocs(b *testing.B) { store := &mockStore{ chunks: map[string][]chunk.Chunk{}, } + readRingMock := mockReadRingWithOneActiveIngester() - i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(b, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -198,7 +208,9 @@ func TestIngester(t *testing.T) { chunks: map[string][]chunk.Chunk{}, } - i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -383,7 +395,9 @@ func TestIngesterStreamLimitExceeded(t *testing.T) { chunks: map[string][]chunk.Chunk{}, } - i, err := New(ingesterConfig, client.Config{}, store, overrides, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, store, overrides, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -803,7 +817,9 @@ func Test_InMemoryLabels(t *testing.T) { chunks: map[string][]chunk.Chunk{}, } - i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -856,8 +872,9 @@ func TestIngester_GetDetectedLabels(t *testing.T) { store := &mockStore{ chunks: map[string][]chunk.Chunk{}, } + readRingMock := mockReadRingWithOneActiveIngester() - i, err := 
New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -919,8 +936,9 @@ func TestIngester_GetDetectedLabelsWithQuery(t *testing.T) { store := &mockStore{ chunks: map[string][]chunk.Chunk{}, } + readRingMock := mockReadRingWithOneActiveIngester() - i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + i, err := New(ingesterConfig, client.Config{}, store, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) defer services.StopAndAwaitTerminated(context.Background(), i) //nolint:errcheck @@ -1286,8 +1304,9 @@ func TestStats(t *testing.T) { ingesterConfig := defaultIngesterTestConfig(t) limits, err := validation.NewOverrides(defaultLimitsTestConfig(), nil) require.NoError(t, err) + readRingMock := mockReadRingWithOneActiveIngester() - i, err := New(ingesterConfig, client.Config{}, &mockStore{}, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + i, err := New(ingesterConfig, client.Config{}, &mockStore{}, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) i.instances["test"] = defaultInstance(t) @@ -1313,8 +1332,9 @@ func TestVolume(t *testing.T) { ingesterConfig := defaultIngesterTestConfig(t) limits, err := validation.NewOverrides(defaultLimitsTestConfig(), nil) require.NoError(t, err) + readRingMock := mockReadRingWithOneActiveIngester() - i, err := New(ingesterConfig, client.Config{}, &mockStore{}, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + i, err := New(ingesterConfig, client.Config{}, &mockStore{}, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) i.instances["test"] = defaultInstance(t) @@ -1392,8 +1412,9 @@ func createIngesterServer(t *testing.T, ingesterConfig Config) (ingesterClient, t.Helper() limits, err := validation.NewOverrides(defaultLimitsTestConfig(), nil) require.NoError(t, err) + readRingMock := mockReadRingWithOneActiveIngester() - ing, err := New(ingesterConfig, client.Config{}, &mockStore{}, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + ing, err := New(ingesterConfig, client.Config{}, &mockStore{}, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) listener := bufconn.Listen(1024 * 1024) @@ -1472,3 +1493,150 @@ func jsonLine(ts int64, i int) string { } return fmt.Sprintf(`{"e":"f", "h":"i", "j":"k", "g":"h", "ts":"%d"}`, ts) } + +type readRingMock struct { + replicationSet ring.ReplicationSet + getAllHealthyCallsCount int + tokenRangesByIngester map[string]ring.TokenRanges +} + +func (r *readRingMock) HealthyInstancesCount() int { + return len(r.replicationSet.Instances) +} + +func newReadRingMock(ingesters []ring.InstanceDesc, maxErrors 
int) *readRingMock { + return &readRingMock{ + replicationSet: ring.ReplicationSet{ + Instances: ingesters, + MaxErrors: maxErrors, + }, + } +} + +func (r *readRingMock) Describe(_ chan<- *prometheus.Desc) { +} + +func (r *readRingMock) Collect(_ chan<- prometheus.Metric) { +} + +func (r *readRingMock) Get(_ uint32, _ ring.Operation, _ []ring.InstanceDesc, _ []string, _ []string) (ring.ReplicationSet, error) { + return r.replicationSet, nil +} + +func (r *readRingMock) ShuffleShard(_ string, size int) ring.ReadRing { + // pass by value to copy + return func(r readRingMock) *readRingMock { + r.replicationSet.Instances = r.replicationSet.Instances[:size] + return &r + }(*r) +} + +func (r *readRingMock) BatchGet(_ []uint32, _ ring.Operation) ([]ring.ReplicationSet, error) { + return []ring.ReplicationSet{r.replicationSet}, nil +} + +func (r *readRingMock) GetAllHealthy(_ ring.Operation) (ring.ReplicationSet, error) { + r.getAllHealthyCallsCount++ + return r.replicationSet, nil +} + +func (r *readRingMock) GetReplicationSetForOperation(_ ring.Operation) (ring.ReplicationSet, error) { + return r.replicationSet, nil +} + +func (r *readRingMock) ReplicationFactor() int { + return 1 +} + +func (r *readRingMock) InstancesCount() int { + return len(r.replicationSet.Instances) +} + +func (r *readRingMock) InstancesInZoneCount(_ string) int { + return len(r.replicationSet.Instances) +} + +func (r *readRingMock) InstancesWithTokensCount() int { + return len(r.replicationSet.Instances) +} + +func (r *readRingMock) InstancesWithTokensInZoneCount(_ string) int { + return len(r.replicationSet.Instances) +} + +func (r *readRingMock) ZonesCount() int { + return 1 +} + +func (r *readRingMock) Subring(_ uint32, _ int) ring.ReadRing { + return r +} + +func (r *readRingMock) HasInstance(instanceID string) bool { + for _, ing := range r.replicationSet.Instances { + if ing.Addr != instanceID { + return true + } + } + return false +} + +func (r *readRingMock) ShuffleShardWithLookback(_ string, _ int, _ time.Duration, _ time.Time) ring.ReadRing { + return r +} + +func (r *readRingMock) CleanupShuffleShardCache(_ string) {} + +func (r *readRingMock) GetInstanceState(_ string) (ring.InstanceState, error) { + return 0, nil +} + +func (r *readRingMock) GetTokenRangesForInstance(instance string) (ring.TokenRanges, error) { + if r.tokenRangesByIngester != nil { + ranges, exists := r.tokenRangesByIngester[instance] + if !exists { + return nil, ring.ErrInstanceNotFound + } + return ranges, nil + } + tr := ring.TokenRanges{0, math.MaxUint32} + return tr, nil +} + +func mockReadRingWithOneActiveIngester() *readRingMock { + return newReadRingMock([]ring.InstanceDesc{ + {Addr: "test", Timestamp: time.Now().UnixNano(), State: ring.ACTIVE, Tokens: []uint32{1, 2, 3}}, + }, 0) +} + +func TestUpdateOwnedStreams(t *testing.T) { + ingesterConfig := defaultIngesterTestConfig(t) + limits, err := validation.NewOverrides(defaultLimitsTestConfig(), nil) + require.NoError(t, err) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, &mockStore{}, limits, runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) + require.NoError(t, err) + + i.instances["test"] = defaultInstance(t) + + tt := time.Now().Add(-5 * time.Minute) + err = i.instances["test"].Push(context.Background(), &logproto.PushRequest{Streams: []logproto.Stream{ + // both label sets have FastFingerprint=e002a3a451262627 + {Labels: "{app=\"l\",uniq0=\"0\",uniq1=\"1\"}", 
Entries: entries(5, tt.Add(time.Minute))}, + {Labels: "{uniq0=\"1\",app=\"m\",uniq1=\"1\"}", Entries: entries(5, tt)}, + + // e002a3a451262247 + {Labels: "{app=\"l\",uniq0=\"1\",uniq1=\"0\"}", Entries: entries(5, tt.Add(time.Minute))}, + {Labels: "{uniq1=\"0\",app=\"m\",uniq0=\"0\"}", Entries: entries(5, tt)}, + + // e002a2a4512624f4 + {Labels: "{app=\"l\",uniq0=\"0\",uniq1=\"0\"}", Entries: entries(5, tt.Add(time.Minute))}, + {Labels: "{uniq0=\"1\",uniq1=\"0\",app=\"m\"}", Entries: entries(5, tt)}, + }}) + require.NoError(t, err) + + // streams are pushed, let's check owned stream counts + ownedStreams := i.instances["test"].ownedStreamsSvc.getOwnedStreamCount() + require.Equal(t, 8, ownedStreams) +} diff --git a/pkg/ingester/instance.go b/pkg/ingester/instance.go index ecef3f10347b..65389a3cb04a 100644 --- a/pkg/ingester/instance.go +++ b/pkg/ingester/instance.go @@ -16,6 +16,7 @@ import ( "github.com/go-kit/log/level" "github.com/grafana/dskit/httpgrpc" + "github.com/grafana/dskit/ring" "github.com/opentracing/opentracing-go" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -186,6 +187,7 @@ func newInstance( customStreamsTracker: customStreamsTracker, } i.mapper = NewFPMapper(i.getLabelsFromFingerprint) + return i, err } @@ -375,10 +377,7 @@ func (i *instance) createStreamByFP(ls labels.Labels, fp model.Fingerprint) (*st s := newStream(chunkfmt, headfmt, i.cfg, i.limiter, i.instanceID, fp, sortedLabels, i.limiter.UnorderedWrites(i.instanceID), i.streamRateCalculator, i.metrics, i.writeFailures) - i.streamsCreatedTotal.Inc() - memoryStreams.WithLabelValues(i.instanceID).Inc() - memoryStreamsLabelsBytes.Add(float64(len(s.labels.String()))) - i.addTailersToNewStream(s) + i.onStreamCreated(s) return s, nil } @@ -1175,3 +1174,23 @@ func minTs(stream *logproto.Stream) model.Time { } return model.TimeFromUnixNano(streamMinTs) } + +// For each stream, we check if the stream is owned by the ingester or not and increment/decrement the owned stream count. 
+func (i *instance) updateOwnedStreams(ownedTokenRange ring.TokenRanges) error { + var err error + i.streams.WithLock(func() { + i.ownedStreamsSvc.resetStreamCounts() + err = i.streams.ForEach(func(s *stream) (bool, error) { + if ownedTokenRange.IncludesKey(uint32(s.fp)) { + i.ownedStreamsSvc.incOwnedStreamCount() + } else { + i.ownedStreamsSvc.incNotOwnedStreamCount() + } + return true, nil + }) + }) + if err != nil { + return fmt.Errorf("error checking streams ownership: %w", err) + } + return nil +} diff --git a/pkg/ingester/instance_test.go b/pkg/ingester/instance_test.go index 3055a7fb0c5b..f5e959165481 100644 --- a/pkg/ingester/instance_test.go +++ b/pkg/ingester/instance_test.go @@ -50,6 +50,7 @@ func defaultConfig() *Config { MaxBackoff: 10 * time.Second, MaxRetries: 1, }, + OwnedStreamsCheckInterval: 1 * time.Second, } if err := cfg.Validate(); err != nil { panic(errors.Wrap(err, "error building default test config")) @@ -1103,7 +1104,8 @@ func TestStreamShardingUsage(t *testing.T) { t.Run("invalid push returns error", func(t *testing.T) { tracker := &mockUsageTracker{} - i, _ := newInstance(&Config{IndexShards: 1}, defaultPeriodConfigs, customTenant1, limiter, loki_runtime.DefaultTenantConfigs(), noopWAL{}, NilMetrics, &OnceSwitch{}, nil, nil, nil, NewStreamRateCalculator(), nil, tracker) + + i, _ := newInstance(&Config{IndexShards: 1, OwnedStreamsCheckInterval: 1 * time.Second}, defaultPeriodConfigs, customTenant1, limiter, loki_runtime.DefaultTenantConfigs(), noopWAL{}, NilMetrics, &OnceSwitch{}, nil, nil, nil, NewStreamRateCalculator(), nil, tracker) ctx := context.Background() err = i.Push(ctx, &logproto.PushRequest{ @@ -1123,7 +1125,7 @@ func TestStreamShardingUsage(t *testing.T) { }) t.Run("valid push returns no error", func(t *testing.T) { - i, _ := newInstance(&Config{IndexShards: 1}, defaultPeriodConfigs, customTenant2, limiter, loki_runtime.DefaultTenantConfigs(), noopWAL{}, NilMetrics, &OnceSwitch{}, nil, nil, nil, NewStreamRateCalculator(), nil, nil) + i, _ := newInstance(&Config{IndexShards: 1, OwnedStreamsCheckInterval: 1 * time.Second}, defaultPeriodConfigs, customTenant2, limiter, loki_runtime.DefaultTenantConfigs(), noopWAL{}, NilMetrics, &OnceSwitch{}, nil, nil, nil, NewStreamRateCalculator(), nil, nil) ctx := context.Background() err = i.Push(ctx, &logproto.PushRequest{ diff --git a/pkg/ingester/owned_streams.go b/pkg/ingester/owned_streams.go index 3be6fb40fdd8..7a405d684dc5 100644 --- a/pkg/ingester/owned_streams.go +++ b/pkg/ingester/owned_streams.go @@ -3,15 +3,16 @@ package ingester import ( "sync" + "github.com/grafana/dskit/services" "go.uber.org/atomic" ) type ownedStreamService struct { - tenantID string - limiter *Limiter - fixedLimit *atomic.Int32 + services.Service - //todo: implement job to recalculate it + tenantID string + limiter *Limiter + fixedLimit *atomic.Int32 ownedStreamCount int notOwnedStreamCount int lock sync.RWMutex @@ -23,6 +24,7 @@ func newOwnedStreamService(tenantID string, limiter *Limiter) *ownedStreamServic limiter: limiter, fixedLimit: atomic.NewInt32(0), } + svc.updateFixedLimit() return svc } @@ -48,6 +50,12 @@ func (s *ownedStreamService) incOwnedStreamCount() { s.ownedStreamCount++ } +func (s *ownedStreamService) incNotOwnedStreamCount() { + s.lock.Lock() + defer s.lock.Unlock() + s.notOwnedStreamCount++ +} + func (s *ownedStreamService) decOwnedStreamCount() { s.lock.Lock() defer s.lock.Unlock() @@ -57,3 +65,10 @@ func (s *ownedStreamService) decOwnedStreamCount() { } s.ownedStreamCount-- } + +func (s 
*ownedStreamService) resetStreamCounts() { + s.lock.Lock() + defer s.lock.Unlock() + s.ownedStreamCount = 0 + s.notOwnedStreamCount = 0 +} diff --git a/pkg/ingester/owned_streams_test.go b/pkg/ingester/owned_streams_test.go index 759927a1d0cf..876954b8579a 100644 --- a/pkg/ingester/owned_streams_test.go +++ b/pkg/ingester/owned_streams_test.go @@ -33,9 +33,12 @@ func Test_OwnedStreamService(t *testing.T) { service.incOwnedStreamCount() require.Equal(t, 3, service.getOwnedStreamCount()) - // simulate the effect from the recalculation job + service.incOwnedStreamCount() + service.decOwnedStreamCount() service.notOwnedStreamCount = 1 service.ownedStreamCount = 2 + require.Equal(t, 2, service.getOwnedStreamCount()) + require.Equal(t, 1, service.notOwnedStreamCount) service.decOwnedStreamCount() require.Equal(t, 2, service.getOwnedStreamCount(), "owned stream count must be decremented only when notOwnedStreamCount is set to 0") @@ -63,4 +66,13 @@ func Test_OwnedStreamService(t *testing.T) { group.Wait() require.Equal(t, 1, service.getOwnedStreamCount(), "owned stream count must not be changed") + + // simulate the effect from the recalculation job + service.notOwnedStreamCount = 1 + service.ownedStreamCount = 2 + + service.resetStreamCounts() + + require.Equal(t, 0, service.getOwnedStreamCount()) + require.Equal(t, 0, service.notOwnedStreamCount) } diff --git a/pkg/ingester/recalculate_owned_streams.go b/pkg/ingester/recalculate_owned_streams.go new file mode 100644 index 000000000000..1346033cb581 --- /dev/null +++ b/pkg/ingester/recalculate_owned_streams.go @@ -0,0 +1,91 @@ +package ingester + +import ( + "context" + "errors" + "time" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" + "github.com/grafana/dskit/ring" + "github.com/grafana/dskit/services" +) + +var ownedStreamRingOp = ring.NewOp([]ring.InstanceState{ring.PENDING, ring.JOINING, ring.ACTIVE, ring.LEAVING}, nil) + +type recalculateOwnedStreams struct { + services.Service + + logger log.Logger + + instancesSupplier func() []*instance + ingesterID string + previousRing ring.ReplicationSet + ingestersRing ring.ReadRing + ticker *time.Ticker +} + +func newRecalculateOwnedStreams(instancesSupplier func() []*instance, ingesterID string, ring ring.ReadRing, ringPollInterval time.Duration, logger log.Logger) *recalculateOwnedStreams { + svc := &recalculateOwnedStreams{ + ingestersRing: ring, + instancesSupplier: instancesSupplier, + ingesterID: ingesterID, + logger: logger, + } + svc.Service = services.NewTimerService(ringPollInterval, nil, svc.iteration, nil) + return svc +} + +func (s *recalculateOwnedStreams) iteration(_ context.Context) error { + s.recalculate() + return nil +} + +func (s *recalculateOwnedStreams) recalculate() { + ringChanged, err := s.checkRingForChanges() + if err != nil { + level.Error(s.logger).Log("msg", "failed to check ring for changes", "err", err) + return + } + if !ringChanged { + return + } + ownedTokenRange, err := s.getTokenRangesForIngester() + if err != nil { + level.Error(s.logger).Log("msg", "failed to get token ranges for ingester", "err", err) + return + } + + for _, instance := range s.instancesSupplier() { + if !instance.limiter.limits.UseOwnedStreamCount(instance.instanceID) { + continue + } + err = instance.updateOwnedStreams(ownedTokenRange) + if err != nil { + level.Error(s.logger).Log("msg", "failed to update owned streams", "err", err) + } + } +} + +func (s *recalculateOwnedStreams) checkRingForChanges() (bool, error) { + rs, err := 
s.ingestersRing.GetAllHealthy(ownedStreamRingOp) + if err != nil { + return false, err + } + + ringChanged := ring.HasReplicationSetChangedWithoutStateOrAddr(s.previousRing, rs) + s.previousRing = rs + return ringChanged, nil +} + +func (s *recalculateOwnedStreams) getTokenRangesForIngester() (ring.TokenRanges, error) { + ranges, err := s.ingestersRing.GetTokenRangesForInstance(s.ingesterID) + if err != nil { + if errors.Is(err, ring.ErrInstanceNotFound) { + return nil, nil + } + return nil, err + } + + return ranges, nil +} diff --git a/pkg/ingester/recalculate_owned_streams_test.go b/pkg/ingester/recalculate_owned_streams_test.go new file mode 100644 index 000000000000..6ac0f54f06e8 --- /dev/null +++ b/pkg/ingester/recalculate_owned_streams_test.go @@ -0,0 +1,154 @@ +package ingester + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/go-kit/log" + "github.com/grafana/dskit/ring" + "github.com/prometheus/common/model" + "github.com/prometheus/prometheus/model/labels" + "github.com/stretchr/testify/require" + + "github.com/grafana/loki/v3/pkg/runtime" + "github.com/grafana/loki/v3/pkg/validation" +) + +func Test_recalculateOwnedStreams_newRecalculateOwnedStreams(t *testing.T) { + mockInstancesSupplier := &mockTenantsSuplier{tenants: []*instance{}} + mockRing := newReadRingMock([]ring.InstanceDesc{ + {Addr: "test", Timestamp: time.Now().UnixNano(), State: ring.ACTIVE, Tokens: []uint32{1, 2, 3}}, + }, 0) + service := newRecalculateOwnedStreams(mockInstancesSupplier.get, "test", mockRing, 50*time.Millisecond, log.NewNopLogger()) + require.Equal(t, 0, mockRing.getAllHealthyCallsCount, "ring must be called only after service's start up") + ctx := context.Background() + require.NoError(t, service.StartAsync(ctx)) + require.NoError(t, service.AwaitRunning(ctx)) + require.Eventually(t, func() bool { + return mockRing.getAllHealthyCallsCount >= 2 + }, 1*time.Second, 50*time.Millisecond, "expected at least two runs of the iteration") +} + +func Test_recalculateOwnedStreams_recalculate(t *testing.T) { + tests := map[string]struct { + featureEnabled bool + expectedOwnedStreamCount int + expectedNotOwnedStreamCount int + }{ + "expected streams ownership to be recalculated": { + featureEnabled: true, + expectedOwnedStreamCount: 4, + expectedNotOwnedStreamCount: 3, + }, + "expected streams ownership recalculation to be skipped": { + featureEnabled: false, + expectedOwnedStreamCount: 7, + }, + } + for testName, testData := range tests { + t.Run(testName, func(t *testing.T) { + mockRing := &readRingMock{ + replicationSet: ring.ReplicationSet{ + Instances: []ring.InstanceDesc{{Addr: "ingester-0", Timestamp: time.Now().UnixNano(), State: ring.ACTIVE, Tokens: []uint32{100, 200, 300}}}, + }, + tokenRangesByIngester: map[string]ring.TokenRanges{ + // this ingester owns token ranges [50, 100] and [200, 300] + "ingester-0": {50, 100, 200, 300}, + }, + } + + limits, err := validation.NewOverrides(validation.Limits{ + MaxGlobalStreamsPerUser: 100, + UseOwnedStreamCount: testData.featureEnabled, + }, nil) + require.NoError(t, err) + limiter := NewLimiter(limits, NilMetrics, mockRing, 1) + + tenant, err := newInstance( + defaultConfig(), + defaultPeriodConfigs, + "tenant-a", + limiter, + runtime.DefaultTenantConfigs(), + noopWAL{}, + NilMetrics, + nil, + nil, + nil, + nil, + NewStreamRateCalculator(), + nil, + nil, + ) + require.NoError(t, err) + // not owned streams + createStream(t, tenant, 49) + createStream(t, tenant, 101) + 
createStream(t, tenant, 301) + + // owned streams + createStream(t, tenant, 50) + createStream(t, tenant, 60) + createStream(t, tenant, 100) + createStream(t, tenant, 250) + + require.Equal(t, 7, tenant.ownedStreamsSvc.ownedStreamCount) + require.Equal(t, 0, tenant.ownedStreamsSvc.notOwnedStreamCount) + + mockTenantsSupplier := &mockTenantsSuplier{tenants: []*instance{tenant}} + + service := newRecalculateOwnedStreams(mockTenantsSupplier.get, "ingester-0", mockRing, 50*time.Millisecond, log.NewNopLogger()) + + service.recalculate() + + require.Equal(t, testData.expectedOwnedStreamCount, tenant.ownedStreamsSvc.ownedStreamCount) + require.Equal(t, testData.expectedNotOwnedStreamCount, tenant.ownedStreamsSvc.notOwnedStreamCount) + }) + } + +} + +func Test_recalculateOwnedStreams_checkRingForChanges(t *testing.T) { + mockRing := &readRingMock{ + replicationSet: ring.ReplicationSet{ + Instances: []ring.InstanceDesc{{Addr: "ingester-0", Timestamp: time.Now().UnixNano(), State: ring.ACTIVE, Tokens: []uint32{100, 200, 300}}}, + }, + } + mockTenantsSupplier := &mockTenantsSuplier{tenants: []*instance{{}}} + service := newRecalculateOwnedStreams(mockTenantsSupplier.get, "ingester-0", mockRing, 50*time.Millisecond, log.NewNopLogger()) + + ringChanged, err := service.checkRingForChanges() + require.NoError(t, err) + require.True(t, ringChanged, "expected ring to be changed because it was not initialized yet") + + ringChanged, err = service.checkRingForChanges() + require.NoError(t, err) + require.False(t, ringChanged, "expected ring not to be changed because token ranges is not changed") + + anotherIngester := ring.InstanceDesc{Addr: "ingester-1", Timestamp: time.Now().UnixNano(), State: ring.ACTIVE, Tokens: []uint32{150, 250, 350}} + mockRing.replicationSet.Instances = append(mockRing.replicationSet.Instances, anotherIngester) + + ringChanged, err = service.checkRingForChanges() + require.NoError(t, err) + require.True(t, ringChanged) +} + +func createStream(t *testing.T, inst *instance, fingerprint int) { + lbls := labels.Labels{ + labels.Label{Name: "mock", Value: strconv.Itoa(fingerprint)}} + + _, _, err := inst.streams.LoadOrStoreNew(lbls.String(), func() (*stream, error) { + return inst.createStreamByFP(lbls, model.Fingerprint(fingerprint)) + }, nil) + require.NoError(t, err) +} + +type mockTenantsSuplier struct { + tenants []*instance +} + +func (m *mockTenantsSuplier) get() []*instance { + return m.tenants +} diff --git a/pkg/ingester/recovery_test.go b/pkg/ingester/recovery_test.go index 9176ff3c6ad2..4c5a4ce815d8 100644 --- a/pkg/ingester/recovery_test.go +++ b/pkg/ingester/recovery_test.go @@ -228,7 +228,9 @@ func TestSeriesRecoveryNoDuplicates(t *testing.T) { chunks: map[string][]chunk.Chunk{}, } - i, err := New(ingesterConfig, client.Config{}, store, limits, loki_runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + readRingMock := mockReadRingWithOneActiveIngester() + + i, err := New(ingesterConfig, client.Config{}, store, limits, loki_runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) mkSample := func(i int) *logproto.PushRequest { @@ -262,7 +264,7 @@ func TestSeriesRecoveryNoDuplicates(t *testing.T) { require.Equal(t, false, iter.Next()) // create a new ingester now - i, err = New(ingesterConfig, client.Config{}, store, limits, loki_runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil) + i, err = 
New(ingesterConfig, client.Config{}, store, limits, loki_runtime.DefaultTenantConfigs(), nil, writefailures.Cfg{}, constants.Loki, log.NewNopLogger(), nil, readRingMock) require.NoError(t, err) // recover the checkpointed series diff --git a/pkg/loki/modules.go b/pkg/loki/modules.go index 22cd46743ea2..c3904d1e681d 100644 --- a/pkg/loki/modules.go +++ b/pkg/loki/modules.go @@ -587,7 +587,7 @@ func (t *Loki) initIngester() (_ services.Service, err error) { level.Warn(util_log.Logger).Log("msg", "The config setting shutdown marker path is not set. The /ingester/prepare_shutdown endpoint won't work") } - t.Ingester, err = ingester.New(t.Cfg.Ingester, t.Cfg.IngesterClient, t.Store, t.Overrides, t.tenantConfigs, prometheus.DefaultRegisterer, t.Cfg.Distributor.WriteFailuresLogging, t.Cfg.MetricsNamespace, logger, t.UsageTracker) + t.Ingester, err = ingester.New(t.Cfg.Ingester, t.Cfg.IngesterClient, t.Store, t.Overrides, t.tenantConfigs, prometheus.DefaultRegisterer, t.Cfg.Distributor.WriteFailuresLogging, t.Cfg.MetricsNamespace, logger, t.UsageTracker, t.ring) if err != nil { return } diff --git a/pkg/querier/querier_mock_test.go b/pkg/querier/querier_mock_test.go index b70ba0cc3963..4d5d2f22c532 100644 --- a/pkg/querier/querier_mock_test.go +++ b/pkg/querier/querier_mock_test.go @@ -451,6 +451,22 @@ func (r *readRingMock) InstancesCount() int { return len(r.replicationSet.Instances) } +func (r *readRingMock) InstancesInZoneCount(_ string) int { + return len(r.replicationSet.Instances) +} + +func (r *readRingMock) InstancesWithTokensCount() int { + return len(r.replicationSet.Instances) +} + +func (r *readRingMock) InstancesWithTokensInZoneCount(_ string) int { + return len(r.replicationSet.Instances) +} + +func (r *readRingMock) ZonesCount() int { + return 1 +} + func (r *readRingMock) Subring(_ uint32, _ int) ring.ReadRing { return r } diff --git a/pkg/util/ring/ring_test.go b/pkg/util/ring/ring_test.go index 8e765d4ef6d4..fa285c0cb4ae 100644 --- a/pkg/util/ring/ring_test.go +++ b/pkg/util/ring/ring_test.go @@ -95,6 +95,44 @@ func (r *readRingMock) GetTokenRangesForInstance(_ string) (ring.TokenRanges, er return tr, nil } +func (r *readRingMock) InstancesInZoneCount(zone string) int { + count := 0 + for _, instance := range r.replicationSet.Instances { + if instance.Zone == zone { + count++ + } + } + return count +} + +func (r *readRingMock) InstancesWithTokensCount() int { + count := 0 + for _, instance := range r.replicationSet.Instances { + if len(instance.Tokens) > 0 { + count++ + } + } + return count +} + +func (r *readRingMock) InstancesWithTokensInZoneCount(zone string) int { + count := 0 + for _, instance := range r.replicationSet.Instances { + if len(instance.Tokens) > 0 && instance.Zone == zone { + count++ + } + } + return count +} + +func (r *readRingMock) ZonesCount() int { + uniqueZone := make(map[string]any) + for _, instance := range r.replicationSet.Instances { + uniqueZone[instance.Zone] = nil + } + return len(uniqueZone) +} + type readLifecyclerMock struct { mock.Mock addr string diff --git a/vendor/github.com/grafana/dskit/concurrency/runner.go b/vendor/github.com/grafana/dskit/concurrency/runner.go index 023be10d7a0a..fcc892997149 100644 --- a/vendor/github.com/grafana/dskit/concurrency/runner.go +++ b/vendor/github.com/grafana/dskit/concurrency/runner.go @@ -83,11 +83,25 @@ func CreateJobsFromStrings(values []string) []interface{} { } // ForEachJob runs the provided jobFunc for each job index in [0, jobs) up to concurrency concurrent 
workers. +// If the concurrency value is <= 0 all jobs will be executed in parallel. +// // The execution breaks on first error encountered. +// +// ForEachJob cancels the context.Context passed to each invocation of jobFunc before ForEachJob returns. func ForEachJob(ctx context.Context, jobs int, concurrency int, jobFunc func(ctx context.Context, idx int) error) error { if jobs == 0 { return nil } + if jobs == 1 { + // Honor the function contract, cancelling the context passed to the jobFunc once it completed. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + return jobFunc(ctx, 0) + } + if concurrency <= 0 { + concurrency = jobs + } // Initialise indexes with -1 so first Inc() returns index 0. indexes := atomic.NewInt64(-1) @@ -113,3 +127,35 @@ func ForEachJob(ctx context.Context, jobs int, concurrency int, jobFunc func(ctx // Wait until done (or context has canceled). return g.Wait() } + +// ForEachJobMergeResults is like ForEachJob but expects jobFunc to return a slice of results which are then +// merged with results from all jobs. This function returns no results if an error occurred running any jobFunc. +// +// ForEachJobMergeResults cancels the context.Context passed to each invocation of jobFunc before ForEachJobMergeResults returns. +func ForEachJobMergeResults[J any, R any](ctx context.Context, jobs []J, concurrency int, jobFunc func(ctx context.Context, job J) ([]R, error)) ([]R, error) { + var ( + resultsMx sync.Mutex + results = make([]R, 0, len(jobs)) // Assume at least 1 result per job. + ) + + err := ForEachJob(ctx, len(jobs), concurrency, func(ctx context.Context, idx int) error { + jobResult, jobErr := jobFunc(ctx, jobs[idx]) + if jobErr != nil { + return jobErr + } + + resultsMx.Lock() + results = append(results, jobResult...) + resultsMx.Unlock() + + return nil + }) + + if err != nil { + return nil, err + } + + // Given no error occurred, it means that all job results have already been collected + // and so it's safe to access results slice with no locking. + return results, nil +} diff --git a/vendor/github.com/grafana/dskit/grpcclient/backoff_retry.go b/vendor/github.com/grafana/dskit/grpcclient/backoff_retry.go index 21abbb786568..b0d7f9004f8d 100644 --- a/vendor/github.com/grafana/dskit/grpcclient/backoff_retry.go +++ b/vendor/github.com/grafana/dskit/grpcclient/backoff_retry.go @@ -2,6 +2,7 @@ package grpcclient import ( "context" + "errors" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -10,22 +11,27 @@ import ( "github.com/grafana/dskit/backoff" ) -// NewBackoffRetry gRPC middleware. -func NewBackoffRetry(cfg backoff.Config) grpc.UnaryClientInterceptor { +// NewRateLimitRetrier creates a UnaryClientInterceptor which retries with backoff +// the calls from invoker when the executed RPC is rate limited. +func NewRateLimitRetrier(cfg backoff.Config) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { backoff := backoff.New(ctx, cfg) + var err error for backoff.Ongoing() { - err := invoker(ctx, method, req, reply, cc, opts...) + err = invoker(ctx, method, req, reply, cc, opts...) if err == nil { return nil } + // Only ResourceExhausted statuses are handled as signals of being rate limited, + // following the implementation of the package's RateLimiter interceptor. + // All other errors are propagated as-is upstream.
if status.Code(err) != codes.ResourceExhausted { return err } backoff.Wait() } - return backoff.Err() + return errors.Join(err, backoff.Err()) } } diff --git a/vendor/github.com/grafana/dskit/grpcclient/grpcclient.go b/vendor/github.com/grafana/dskit/grpcclient/grpcclient.go index b171889d0a04..751899047154 100644 --- a/vendor/github.com/grafana/dskit/grpcclient/grpcclient.go +++ b/vendor/github.com/grafana/dskit/grpcclient/grpcclient.go @@ -108,7 +108,7 @@ func (cfg *Config) DialOption(unaryClientInterceptors []grpc.UnaryClientIntercep streamClientInterceptors = append(streamClientInterceptors, cfg.StreamMiddleware...) if cfg.BackoffOnRatelimits { - unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{NewBackoffRetry(cfg.BackoffConfig)}, unaryClientInterceptors...) + unaryClientInterceptors = append([]grpc.UnaryClientInterceptor{NewRateLimitRetrier(cfg.BackoffConfig)}, unaryClientInterceptors...) } if cfg.RateLimit > 0 { diff --git a/vendor/github.com/grafana/dskit/httpgrpc/httpgrpc.go b/vendor/github.com/grafana/dskit/httpgrpc/httpgrpc.go index e1f044d8650b..b755e2adceae 100644 --- a/vendor/github.com/grafana/dskit/httpgrpc/httpgrpc.go +++ b/vendor/github.com/grafana/dskit/httpgrpc/httpgrpc.go @@ -89,7 +89,9 @@ func WriteError(w http.ResponseWriter, err error) { func ToHeader(hs []*Header, header http.Header) { for _, h := range hs { - header[h.Key] = h.Values + // http.Header expects headers to be stored in canonical form; + // otherwise they are inaccessible with Get() on an http.Header struct. + header[http.CanonicalHeaderKey(h.Key)] = h.Values } } diff --git a/vendor/github.com/grafana/dskit/instrument/instrument.go b/vendor/github.com/grafana/dskit/instrument/instrument.go index 4ea480b29d60..f54e49def308 100644 --- a/vendor/github.com/grafana/dskit/instrument/instrument.go +++ b/vendor/github.com/grafana/dskit/instrument/instrument.go @@ -75,7 +75,7 @@ func ObserveWithExemplar(ctx context.Context, histogram prometheus.Observer, sec if traceID, ok := tracing.ExtractSampledTraceID(ctx); ok { histogram.(prometheus.ExemplarObserver).ObserveWithExemplar( seconds, - prometheus.Labels{"traceID": traceID}, + prometheus.Labels{"trace_id": traceID, "traceID": traceID}, ) return } diff --git a/vendor/github.com/grafana/dskit/kv/memberlist/memberlist_client.go b/vendor/github.com/grafana/dskit/kv/memberlist/memberlist_client.go index 693964b5ad06..e8a94debe181 100644 --- a/vendor/github.com/grafana/dskit/kv/memberlist/memberlist_client.go +++ b/vendor/github.com/grafana/dskit/kv/memberlist/memberlist_client.go @@ -177,7 +177,7 @@ func (cfg *KVConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix string) { // "Defaults to hostname" -- memberlist sets it to hostname by default. f.StringVar(&cfg.NodeName, prefix+"memberlist.nodename", "", "Name of the node in memberlist cluster. Defaults to hostname.") // memberlist.DefaultLANConfig will put hostname here. f.BoolVar(&cfg.RandomizeNodeName, prefix+"memberlist.randomize-node-name", true, "Add random suffix to the node name.")
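The ToHeader fix above matters because Go's http.Header canonicalizes keys on access; a self-contained sketch of the failure mode (the header name is chosen for illustration):

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	h := http.Header{}
	h["x-scope-orgid"] = []string{"tenant-1"} // stored under a non-canonical key
	fmt.Println(h.Get("X-Scope-Orgid"))       // prints "" because Get canonicalizes the key and misses it

	h[http.CanonicalHeaderKey("x-scope-orgid")] = []string{"tenant-1"}
	fmt.Println(h.Get("X-Scope-Orgid")) // prints "tenant-1"
}
```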
f.BoolVar(&cfg.RandomizeNodeName, prefix+"memberlist.randomize-node-name", true, "Add random suffix to the node name.") - f.DurationVar(&cfg.StreamTimeout, prefix+"memberlist.stream-timeout", mlDefaults.TCPTimeout, "The timeout for establishing a connection with a remote node, and for read/write operations.") + f.DurationVar(&cfg.StreamTimeout, prefix+"memberlist.stream-timeout", 2*time.Second, "The timeout for establishing a connection with a remote node, and for read/write operations.") f.IntVar(&cfg.RetransmitMult, prefix+"memberlist.retransmit-factor", mlDefaults.RetransmitMult, "Multiplication factor used when sending out messages (factor * log(N+1)).") f.Var(&cfg.JoinMembers, prefix+"memberlist.join", "Other cluster members to join. Can be specified multiple times. It can be an IP, hostname or an entry specified in the DNS Service Discovery format.") f.DurationVar(&cfg.MinJoinBackoff, prefix+"memberlist.min-join-backoff", 1*time.Second, "Min backoff duration to join other cluster members.") diff --git a/vendor/github.com/grafana/dskit/kv/memberlist/mergeable.go b/vendor/github.com/grafana/dskit/kv/memberlist/mergeable.go index 2c02acfa468e..9833a858b476 100644 --- a/vendor/github.com/grafana/dskit/kv/memberlist/mergeable.go +++ b/vendor/github.com/grafana/dskit/kv/memberlist/mergeable.go @@ -35,17 +35,17 @@ type Mergeable interface { // used when doing CAS operation) Merge(other Mergeable, localCAS bool) (change Mergeable, error error) - // Describes the content of this mergeable value. Used by memberlist client to decide if + // MergeContent describes the content of this mergeable value. Used by memberlist client to decide if // one change-value can invalidate some other value, that was received previously. // Invalidation can happen only if output of MergeContent is a superset of some other MergeContent. MergeContent() []string - // Remove tombstones older than given limit from this mergeable. + // RemoveTombstones remove tombstones older than given limit from this mergeable. // If limit is zero time, remove all tombstones. Memberlist client calls this method with zero limit each // time when client is accessing value from the store. It can be used to hide tombstones from the clients. // Returns the total number of tombstones present and the number of removed tombstones by this invocation. RemoveTombstones(limit time.Time) (total, removed int) - // Clone should return a deep copy of the state. + // Clone returns a deep copy of the state. Clone() Mergeable } diff --git a/vendor/github.com/grafana/dskit/kv/metrics.go b/vendor/github.com/grafana/dskit/kv/metrics.go index 7361b8c41c78..954f06ed30b4 100644 --- a/vendor/github.com/grafana/dskit/kv/metrics.go +++ b/vendor/github.com/grafana/dskit/kv/metrics.go @@ -3,6 +3,7 @@ package kv import ( "context" "strconv" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -53,6 +54,10 @@ func newMetricsClient(backend string, c Client, reg prometheus.Registerer) Clien Name: "kv_request_duration_seconds", Help: "Time spent on kv store requests.", Buckets: prometheus.DefBuckets, + // Use defaults recommended by Prometheus for native histograms. 
+ NativeHistogramBucketFactor: 1.1, + NativeHistogramMaxBucketNumber: 100, + NativeHistogramMinResetDuration: time.Hour, ConstLabels: prometheus.Labels{ "type": backend, }, diff --git a/vendor/github.com/grafana/dskit/middleware/grpc_instrumentation.go b/vendor/github.com/grafana/dskit/middleware/grpc_instrumentation.go index e4052b8ed05f..d15402ea484d 100644 --- a/vendor/github.com/grafana/dskit/middleware/grpc_instrumentation.go +++ b/vendor/github.com/grafana/dskit/middleware/grpc_instrumentation.go @@ -22,7 +22,20 @@ import ( ) func observe(ctx context.Context, hist *prometheus.HistogramVec, method string, err error, duration time.Duration, instrumentLabel instrumentationLabel) { - instrument.ObserveWithExemplar(ctx, hist.WithLabelValues(gRPC, method, instrumentLabel.getInstrumentationLabel(err), "false"), duration.Seconds()) + labelValues := []string{ + gRPC, + method, + instrumentLabel.getInstrumentationLabel(err), + "false", + "", // this is a placeholder for the tenant ID + } + labelValues = labelValues[:len(labelValues)-1] + + instrument.ObserveWithExemplar(ctx, hist.WithLabelValues(labelValues...), duration.Seconds()) + if tenantID, ok := instrumentLabel.perTenantInstrumentation.shouldInstrument(ctx); ok { + labelValues = append(labelValues, tenantID) + instrument.ObserveWithExemplar(ctx, instrumentLabel.perTenantDuration.WithLabelValues(labelValues...), duration.Seconds()) + } } // UnaryServerInstrumentInterceptor instruments gRPC requests for errors and latency. @@ -182,8 +195,17 @@ var ( } ) +func WithPerTenantInstrumentation(m *prometheus.HistogramVec, f PerTenantCallback) InstrumentationOption { + return func(instrumentationLabel *instrumentationLabel) { + instrumentationLabel.perTenantInstrumentation = f + instrumentationLabel.perTenantDuration = m + } +} + func applyInstrumentationOptions(maskHTTPStatuses bool, options ...InstrumentationOption) instrumentationLabel { - instrumentationLabel := instrumentationLabel{maskHTTPStatus: maskHTTPStatuses} + instrumentationLabel := instrumentationLabel{ + maskHTTPStatus: maskHTTPStatuses, + } for _, opt := range options { opt(&instrumentationLabel) } @@ -191,8 +213,10 @@ func applyInstrumentationOptions(maskHTTPStatuses bool, options ...Instrumentati } type instrumentationLabel struct { - reportGRPCStatus bool - maskHTTPStatus bool + reportGRPCStatus bool + maskHTTPStatus bool + perTenantInstrumentation PerTenantCallback + perTenantDuration *prometheus.HistogramVec } // getInstrumentationLabel converts an error into an error code string by applying the configurations diff --git a/vendor/github.com/grafana/dskit/middleware/grpc_logging.go b/vendor/github.com/grafana/dskit/middleware/grpc_logging.go index feab36474322..68a2ce037ee6 100644 --- a/vendor/github.com/grafana/dskit/middleware/grpc_logging.go +++ b/vendor/github.com/grafana/dskit/middleware/grpc_logging.go @@ -6,11 +6,12 @@ package middleware import ( "context" - "errors" + "fmt" "time" "github.com/go-kit/log" "github.com/go-kit/log/level" + "github.com/pkg/errors" dskit_log "github.com/grafana/dskit/log" @@ -24,16 +25,20 @@ const ( gRPC = "gRPC" ) -// An error can implement ShouldLog() to control whether GRPCServerLog will log. +// OptionalLogging is the interface that needs to be implemented by an error that wants to control whether +// it should be logged by GRPCServerLog.
type OptionalLogging interface { - ShouldLog(ctx context.Context, duration time.Duration) bool + // ShouldLog returns whether the error should be logged and the reason. For example, if the error is sampled, the + // returned reason could be something like "sampled 1/10". The reason, if any, is used to decorate the error + // regardless of whether the error is logged or skipped. + ShouldLog(ctx context.Context) (bool, string) } type DoNotLogError struct{ Err error } -func (i DoNotLogError) Error() string { return i.Err.Error() } -func (i DoNotLogError) Unwrap() error { return i.Err } -func (i DoNotLogError) ShouldLog(_ context.Context, _ time.Duration) bool { return false } +func (i DoNotLogError) Error() string { return i.Err.Error() } +func (i DoNotLogError) Unwrap() error { return i.Err } +func (i DoNotLogError) ShouldLog(_ context.Context) (bool, string) { return false, "" } // GRPCServerLog logs grpc requests, errors, and latency. type GRPCServerLog struct { @@ -50,8 +55,13 @@ func (s GRPCServerLog) UnaryServerInterceptor(ctx context.Context, req interface if err == nil && s.DisableRequestSuccessLog { return resp, nil } - var optional OptionalLogging - if errors.As(err, &optional) && !optional.ShouldLog(ctx, time.Since(begin)) { + + // Honor sampled error logging. + keep, reason := shouldLog(ctx, err) + if reason != "" { + err = fmt.Errorf("%w (%s)", err, reason) + } + if !keep { return resp, err } @@ -91,3 +101,12 @@ func (s GRPCServerLog) StreamServerInterceptor(srv interface{}, ss grpc.ServerSt } return err } + +func shouldLog(ctx context.Context, err error) (bool, string) { + var optional OptionalLogging + if !errors.As(err, &optional) { + return true, "" + } + + return optional.ShouldLog(ctx) +} diff --git a/vendor/github.com/grafana/dskit/middleware/grpc_stats.go b/vendor/github.com/grafana/dskit/middleware/grpc_stats.go index 3d29d9baabbe..ec766d640acf 100644 --- a/vendor/github.com/grafana/dskit/middleware/grpc_stats.go +++ b/vendor/github.com/grafana/dskit/middleware/grpc_stats.go @@ -31,6 +31,7 @@ type contextKey int const ( contextKeyMethodName contextKey = 1 + contextKeyRouteName contextKey = 2 ) func (g *grpcStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { diff --git a/vendor/github.com/grafana/dskit/middleware/http_timeout.go b/vendor/github.com/grafana/dskit/middleware/http_timeout.go new file mode 100644 index 000000000000..15b1a3f2e92f --- /dev/null +++ b/vendor/github.com/grafana/dskit/middleware/http_timeout.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "net/http" + "time" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" +) + +// NewTimeoutMiddleware returns a new timeout middleware that returns a 503 Service Unavailable +// using the http.TimeoutHandler. Note also that the middleware disables the http server write timeout +// to ensure the two timeouts don't conflict. We disable the server write timeout because its behavior may +// be unintuitive. See below. +// +// Server.WriteTimeout: +// - does not cancel context and instead allows the request to go until the next write.
in practice this +// means that an http server with a write timeout of 10s may go for significantly longer +// - closes the tcp connection on the next write after the timeout has elapsed instead of sending a +// meaningful http response +// - allows streaming of http response back to caller +// +// http.TimeoutHandler +// - cancels context allowing downstream code to abandon the request +// - returns a 503 Service Unavailable with the provided message +// - buffers response in memory which may be undesirable for large responses +func NewTimeoutMiddleware(dt time.Duration, msg string, log log.Logger) Func { + return func(next http.Handler) http.Handler { + return &timeoutHandler{ + log: log, + handler: http.TimeoutHandler(next, dt, msg), + } + } +} + +type timeoutHandler struct { + log log.Logger + handler http.Handler +} + +func (t timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + rc := http.NewResponseController(w) + // setting the write deadline to the zero time disables it + err := rc.SetWriteDeadline(time.Time{}) + if err != nil { + level.Warn(t.log).Log("msg", "failed to set write deadline in timeout handler. server WriteTimeout is still enforced", "err", err) + } + + t.handler.ServeHTTP(w, r) +} diff --git a/vendor/github.com/grafana/dskit/middleware/http_tracing.go b/vendor/github.com/grafana/dskit/middleware/http_tracing.go index 901970a4a6b1..d75535ebe38c 100644 --- a/vendor/github.com/grafana/dskit/middleware/http_tracing.go +++ b/vendor/github.com/grafana/dskit/middleware/http_tracing.go @@ -24,14 +24,13 @@ var _ = nethttp.MWURLTagFunc // Tracer is a middleware which traces incoming requests. type Tracer struct { - RouteMatcher RouteMatcher - SourceIPs *SourceIPExtractor + SourceIPs *SourceIPExtractor } // Wrap implements Interface func (t Tracer) Wrap(next http.Handler) http.Handler { options := []nethttp.MWOption{ - nethttp.OperationNameFunc(makeHTTPOperationNameFunc(t.RouteMatcher)), + nethttp.OperationNameFunc(httpOperationNameFunc), nethttp.MWSpanObserver(func(sp opentracing.Span, r *http.Request) { // add a tag with the client's user agent to the span userAgent := r.Header.Get("User-Agent") @@ -130,11 +129,9 @@ func HTTPGRPCTracingInterceptor(router *mux.Router) grpc.UnaryServerInterceptor } } -func makeHTTPOperationNameFunc(routeMatcher RouteMatcher) func(r *http.Request) string { - return func(r *http.Request) string { - routeName := getRouteName(routeMatcher, r) - return getOperationName(routeName, r) - } +func httpOperationNameFunc(r *http.Request) string { + routeName := ExtractRouteName(r.Context()) + return getOperationName(routeName, r) } func getOperationName(routeName string, r *http.Request) string { diff --git a/vendor/github.com/grafana/dskit/middleware/instrument.go b/vendor/github.com/grafana/dskit/middleware/instrument.go index e5ae9c53c98c..9813077ce6c2 100644 --- a/vendor/github.com/grafana/dskit/middleware/instrument.go +++ b/vendor/github.com/grafana/dskit/middleware/instrument.go @@ -5,9 +5,9 @@ package middleware import ( + "context" "io" "net/http" - "regexp" "strconv" "strings" @@ -28,13 +28,28 @@ type RouteMatcher interface { Match(*http.Request, *mux.RouteMatch) bool } +// PerTenantCallback is a function that returns a tenant ID for a given request. When the returned tenant ID is not empty, it is used to label the duration histogram. 
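A minimal wiring sketch for NewTimeoutMiddleware (the handler, the 30s budget, and the message are assumptions for illustration):

```go
package main

import (
	"net/http"
	"time"

	"github.com/go-kit/log"
	"github.com/grafana/dskit/middleware"
)

func main() {
	var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		time.Sleep(time.Minute) // simulate slow work
		w.WriteHeader(http.StatusOK)
	})

	// Requests running longer than 30s receive a 503 with the given message; the
	// middleware also clears the per-request write deadline so the two timeouts don't conflict.
	handler = middleware.NewTimeoutMiddleware(30*time.Second, "request timed out", log.NewNopLogger()).Wrap(handler)
	_ = http.ListenAndServe(":8080", handler)
}
```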
+type PerTenantCallback func(context.Context) string + +func (f PerTenantCallback) shouldInstrument(ctx context.Context) (string, bool) { + if f == nil { + return "", false + } + tenantID := f(ctx) + if tenantID == "" { + return "", false + } + return tenantID, true +} + // Instrument is a Middleware which records timings for every HTTP request type Instrument struct { - RouteMatcher RouteMatcher - Duration *prometheus.HistogramVec - RequestBodySize *prometheus.HistogramVec - ResponseBodySize *prometheus.HistogramVec - InflightRequests *prometheus.GaugeVec + Duration *prometheus.HistogramVec + PerTenantDuration *prometheus.HistogramVec + PerTenantCallback PerTenantCallback + RequestBodySize *prometheus.HistogramVec + ResponseBodySize *prometheus.HistogramVec + InflightRequests *prometheus.GaugeVec } // IsWSHandshakeRequest returns true if the given request is a websocket handshake request. @@ -77,7 +92,19 @@ func (i Instrument) Wrap(next http.Handler) http.Handler { i.RequestBodySize.WithLabelValues(r.Method, route).Observe(float64(rBody.read)) i.ResponseBodySize.WithLabelValues(r.Method, route).Observe(float64(respMetrics.Written)) - instrument.ObserveWithExemplar(r.Context(), i.Duration.WithLabelValues(r.Method, route, strconv.Itoa(respMetrics.Code), isWS), respMetrics.Duration.Seconds()) + labelValues := []string{ + r.Method, + route, + strconv.Itoa(respMetrics.Code), + isWS, + "", // this is a placeholder for the tenant ID + } + labelValues = labelValues[:len(labelValues)-1] + instrument.ObserveWithExemplar(r.Context(), i.Duration.WithLabelValues(labelValues...), respMetrics.Duration.Seconds()) + if tenantID, ok := i.PerTenantCallback.shouldInstrument(r.Context()); ok { + labelValues = append(labelValues, tenantID) + instrument.ObserveWithExemplar(r.Context(), i.PerTenantDuration.WithLabelValues(labelValues...), respMetrics.Duration.Seconds()) + } }) } @@ -91,7 +118,7 @@ func (i Instrument) Wrap(next http.Handler) http.Handler { // We do all this as we do not wish to emit high cardinality labels to // prometheus. func (i Instrument) getRouteName(r *http.Request) string { - route := getRouteName(i.RouteMatcher, r) + route := ExtractRouteName(r.Context()) if route == "" { route = "other" } @@ -99,53 +126,6 @@ func (i Instrument) getRouteName(r *http.Request) string { return route } -func getRouteName(routeMatcher RouteMatcher, r *http.Request) string { - var routeMatch mux.RouteMatch - if routeMatcher == nil || !routeMatcher.Match(r, &routeMatch) { - return "" - } - - if routeMatch.MatchErr == mux.ErrNotFound { - return "notfound" - } - - if routeMatch.Route == nil { - return "" - } - - if name := routeMatch.Route.GetName(); name != "" { - return name - } - - tmpl, err := routeMatch.Route.GetPathTemplate() - if err == nil { - return MakeLabelValue(tmpl) - } - - return "" -} - -var invalidChars = regexp.MustCompile(`[^a-zA-Z0-9]+`) - -// MakeLabelValue converts a Gorilla mux path to a string suitable for use in -// a Prometheus label value. -func MakeLabelValue(path string) string { - // Convert non-alnums to underscores. - result := invalidChars.ReplaceAllString(path, "_") - - // Trim leading and trailing underscores. - result = strings.Trim(result, "_") - - // Make it all lowercase - result = strings.ToLower(result) - - // Special case. 
- if result == "" { - result = "root" - } - return result -} - type reqBody struct { b io.ReadCloser read int64 diff --git a/vendor/github.com/grafana/dskit/middleware/logging.go b/vendor/github.com/grafana/dskit/middleware/logging.go index aeb15cc6b63a..fe00d3a82846 100644 --- a/vendor/github.com/grafana/dskit/middleware/logging.go +++ b/vendor/github.com/grafana/dskit/middleware/logging.go @@ -58,7 +58,7 @@ func (l Log) logWithRequest(r *http.Request) log.Logger { localLog := l.Log traceID, ok := tracing.ExtractTraceID(r.Context()) if ok { - localLog = log.With(localLog, "traceID", traceID) + localLog = log.With(localLog, "trace_id", traceID) } if l.SourceIPs != nil { diff --git a/vendor/github.com/grafana/dskit/middleware/route_injector.go b/vendor/github.com/grafana/dskit/middleware/route_injector.go new file mode 100644 index 000000000000..7b275f74f756 --- /dev/null +++ b/vendor/github.com/grafana/dskit/middleware/route_injector.go @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: AGPL-3.0-only + +package middleware + +import ( + "context" + "net/http" + "regexp" + "strings" + + "github.com/gorilla/mux" +) + +// RouteInjector is a middleware that injects the route name for the current request into the request context. +// +// The route name can be retrieved by calling ExtractRouteName. +type RouteInjector struct { + RouteMatcher RouteMatcher +} + +func (i RouteInjector) Wrap(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + routeName := getRouteName(i.RouteMatcher, r) + handler.ServeHTTP(w, WithRouteName(r, routeName)) + }) +} + +// WithRouteName annotates r's context with the provided route name. +// +// The provided value must be suitable to use as a Prometheus label value. +// +// This method should generally only be used in tests: in production code, use RouteInjector instead. +func WithRouteName(r *http.Request, routeName string) *http.Request { + ctx := context.WithValue(r.Context(), contextKeyRouteName, routeName) + return r.WithContext(ctx) +} + +// ExtractRouteName returns the route name associated with this request that was previously injected by the +// RouteInjector middleware or WithRouteName. +// +// This is the same route name used for trace and metric names, and is already suitable for use as a Prometheus label +// value. +func ExtractRouteName(ctx context.Context) string { + routeName, ok := ctx.Value(contextKeyRouteName).(string) + if !ok { + return "" + } + + return routeName +} + +func getRouteName(routeMatcher RouteMatcher, r *http.Request) string { + var routeMatch mux.RouteMatch + if routeMatcher == nil || !routeMatcher.Match(r, &routeMatch) { + return "" + } + + if routeMatch.MatchErr == mux.ErrNotFound { + return "notfound" + } + + if routeMatch.Route == nil { + return "" + } + + if name := routeMatch.Route.GetName(); name != "" { + return name + } + + tmpl, err := routeMatch.Route.GetPathTemplate() + if err == nil { + return MakeLabelValue(tmpl) + } + + return "" +} + +var invalidChars = regexp.MustCompile(`[^a-zA-Z0-9]+`) + +// MakeLabelValue converts a Gorilla mux path to a string suitable for use in +// a Prometheus label value. +func MakeLabelValue(path string) string { + // Convert non-alnums to underscores. + result := invalidChars.ReplaceAllString(path, "_") + + // Trim leading and trailing underscores. + result = strings.Trim(result, "_") + + // Make it all lowercase + result = strings.ToLower(result) + + // Special case. 
+ if result == "" { + result = "root" + } + return result +} diff --git a/vendor/github.com/grafana/dskit/middleware/source_ips.go b/vendor/github.com/grafana/dskit/middleware/source_ips.go index 7c035ddbf47e..d08797abb09b 100644 --- a/vendor/github.com/grafana/dskit/middleware/source_ips.go +++ b/vendor/github.com/grafana/dskit/middleware/source_ips.go @@ -18,6 +18,9 @@ var ( // De-facto standard header keys. xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") xRealIP = http.CanonicalHeaderKey("X-Real-IP") + // Allows extracting the host from the X-Forwarded-For header. + // Strips any spaces or double quotes surrounding the host. + xForwardedForRegex = regexp.MustCompile(`(?: *"?([^,]+)"? *)`) ) var ( @@ -25,9 +28,9 @@ var ( // existing use of X-Forwarded-* headers. // e.g. Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43 forwarded = http.CanonicalHeaderKey("Forwarded") - // Allows for a sub-match of the first value after 'for=' to the next - // comma, semi-colon or space. The match is case-insensitive. - forRegex = regexp.MustCompile(`(?i)(?:for=)([^(;|,| )]+)`) + // Allows extracting the host from the for clause of the Forwarded header. + // Strips any spaces or double quotes surrounding the host. + forwardedRegex = regexp.MustCompile(`(?i)(?:for=)(?: *"?([^;,]+)"? *)`) ) // SourceIPExtractor extracts the source IPs from an HTTP request type SourceIPExtractor struct { // A regex that extracts the IP address from the header. // It should contain at least one capturing group, the first of which will be returned. regex *regexp.Regexp + // Whether to return all found IPs or just the first match + extractAllHosts bool } // NewSourceIPs creates a new SourceIPs -func NewSourceIPs(header, regex string) (*SourceIPExtractor, error) { +func NewSourceIPs(header, regex string, extractAllHosts bool) (*SourceIPExtractor, error) { if (header == "" && regex != "") || (header != "" && regex == "") { return nil, fmt.Errorf("either both a header field and a regex have to be given or neither") } @@ -50,8 +55,9 @@ func NewSourceIPs(header, regex string) (*SourceIPExtractor, error) { return &SourceIPExtractor{ - header: header, - regex: re, + header: header, + regex: re, + extractAllHosts: extractAllHosts, }, nil } @@ -72,7 +78,15 @@ func extractHost(address string) string { // Get returns any source addresses we can find in the request, comma-separated func (sips SourceIPExtractor) Get(req *http.Request) string { - fwd := extractHost(sips.getIP(req)) + hosts := []string{} + + // Remove port information from each extracted address + for _, addr := range sips.getIP(req) { + hosts = append(hosts, extractHost(addr)) + } + + fwd := strings.Join(hosts, ", ") + if fwd == "" { if req.RemoteAddr == "" { return "" } @@ -94,52 +108,45 @@ func (sips SourceIPExtractor) Get(req *http.Request) string { // getIP retrieves the IP from the RFC7239 Forwarded headers, // X-Real-IP and X-Forwarded-For (in that order) or from the // custom regex.
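A sketch of the new multi-host extraction (the request and header values are made up; passing extractAllHosts=false keeps the old first-match-only behavior):

```go
package main

import (
	"fmt"
	"net/http"

	"github.com/grafana/dskit/middleware"
)

func main() {
	// No custom header/regex pair; extract every host found rather than only the first.
	sips, err := middleware.NewSourceIPs("", "", true)
	if err != nil {
		panic(err)
	}

	req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
	req.Header.Set("Forwarded", "for=192.0.2.60, for=203.0.113.43")
	fmt.Println(sips.Get(req)) // e.g. "192.0.2.60, 203.0.113.43"
}
```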
-func (sips SourceIPExtractor) getIP(r *http.Request) string { - var addr string +func (sips SourceIPExtractor) getIP(r *http.Request) []string { + var addrs = []string{} // Use the custom regex only if it was setup if sips.header != "" { hdr := r.Header.Get(sips.header) if hdr == "" { - return "" - } - allMatches := sips.regex.FindAllStringSubmatch(hdr, 1) - if len(allMatches) == 0 { - return "" - } - firstMatch := allMatches[0] - // Check there is at least 1 submatch - if len(firstMatch) < 2 { - return "" + return addrs } - return firstMatch[1] - } - if fwd := r.Header.Get(forwarded); fwd != "" { - // match should contain at least two elements if the protocol was - // specified in the Forwarded header. The first element will always be - // the 'for=' capture, which we ignore. In the case of multiple IP - // addresses (for=8.8.8.8, 8.8.4.4,172.16.1.20 is valid) we only - // extract the first, which should be the client IP. - if match := forRegex.FindStringSubmatch(fwd); len(match) > 1 { - // IPv6 addresses in Forwarded headers are quoted-strings. We strip - // these quotes. - addr = strings.Trim(match[1], `"`) - } + addrs = sips.extractHeader(hdr, sips.regex) + } else if fwd := r.Header.Get(forwarded); fwd != "" { + addrs = sips.extractHeader(fwd, forwardedRegex) } else if fwd := r.Header.Get(xRealIP); fwd != "" { // X-Real-IP should only contain one IP address (the client making the // request). - addr = fwd + addrs = append([]string{}, fwd) } else if fwd := strings.ReplaceAll(r.Header.Get(xForwardedFor), " ", ""); fwd != "" { - // Only grab the first (client) address. Note that '192.168.0.1, - // 10.1.1.1' is a valid key for X-Forwarded-For where addresses after - // the first may represent forwarding proxies earlier in the chain. - s := strings.Index(fwd, ",") - if s == -1 { - s = len(fwd) + addrs = sips.extractHeader(fwd, xForwardedForRegex) + } + + return addrs +} + +// extractHeader is a toolbox function that will parse a header content with a regex and return a list +// of all matching groups as string. +func (sips SourceIPExtractor) extractHeader(header string, regex *regexp.Regexp) []string { + var addrs = []string{} + + if allMatches := regex.FindAllStringSubmatch(header, -1); len(allMatches) > 0 { + for _, match := range allMatches { + if len(match) > 1 { + addrs = append(addrs, match[1]) + } + if !sips.extractAllHosts { + break + } } - addr = fwd[:s] } - return addr + return addrs } diff --git a/vendor/github.com/grafana/dskit/multierror/multierror.go b/vendor/github.com/grafana/dskit/multierror/multierror.go index 68b73e201c1e..d48b76e6f1a1 100644 --- a/vendor/github.com/grafana/dskit/multierror/multierror.go +++ b/vendor/github.com/grafana/dskit/multierror/multierror.go @@ -5,7 +5,6 @@ package multierror import ( "bytes" - "errors" "fmt" ) @@ -62,14 +61,6 @@ func (es nonNilMultiError) Error() string { return buf.String() } -// Is attempts to match the provided error against errors in the error list. -// -// This function allows errors.Is to traverse the values stored in the MultiError. 
-func (es nonNilMultiError) Is(target error) bool { - for _, err := range es { - if errors.Is(err, target) { - return true - } - } - return false +func (es nonNilMultiError) Unwrap() []error { + return es } diff --git a/vendor/github.com/grafana/dskit/ring/batch.go b/vendor/github.com/grafana/dskit/ring/batch.go index 7781fe67a5ae..f982bd6c68c3 100644 --- a/vendor/github.com/grafana/dskit/ring/batch.go +++ b/vendor/github.com/grafana/dskit/ring/batch.go @@ -49,9 +49,26 @@ func isHTTPStatus4xx(err error) bool { return code/100 == 4 } +// DoBatchRing defines the interface required by a ring implementation to use DoBatch() and DoBatchWithOptions(). +type DoBatchRing interface { + // Get returns a ReplicationSet containing the instances to which the input key should be sharded + // for the input Operation. + // + // The input buffers may be referenced in the returned ReplicationSet. This means that it's unsafe to call + // Get() multiple times passing the same buffers if ReplicationSet is retained between two different Get() + // calls. In this case, you can pass nil buffers. + Get(key uint32, op Operation, bufInstances []InstanceDesc, bufStrings1, bufStrings2 []string) (ReplicationSet, error) + + // ReplicationFactor returns the number of instances each key is expected to be sharded to. + ReplicationFactor() int + + // InstancesCount returns the number of instances in the ring eligible to have keys sharded to them. + InstancesCount() int +} // DoBatch is a deprecated version of DoBatchWithOptions where grpc errors containing status codes 4xx are treated as client errors. // Deprecated. Use DoBatchWithOptions instead. -func DoBatch(ctx context.Context, op Operation, r ReadRing, keys []uint32, callback func(InstanceDesc, []int) error, cleanup func()) error { +func DoBatch(ctx context.Context, op Operation, r DoBatchRing, keys []uint32, callback func(InstanceDesc, []int) error, cleanup func()) error { return DoBatchWithOptions(ctx, op, r, keys, callback, DoBatchOptions{ Cleanup: cleanup, IsClientError: isHTTPStatus4xx, @@ -94,14 +111,14 @@ func (o *DoBatchOptions) replaceZeroValuesWithDefaults() { // See comments on DoBatchOptions for available options for this call. // // Not implemented as a method on Ring, so we can test separately.
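A usage sketch against the new DoBatchRing interface (the push function and the keys are caller-owned assumptions, not part of the diff):

```go
package example

import (
	"context"

	"github.com/grafana/dskit/ring"
)

// pushBatch shards keys across ring instances, invoking one callback per instance.
func pushBatch(ctx context.Context, r ring.DoBatchRing, keys []uint32, push func(addr string, indexes []int) error) error {
	return ring.DoBatchWithOptions(ctx, ring.Write, r, keys,
		func(instance ring.InstanceDesc, indexes []int) error {
			// indexes point back into keys; send those items to this instance.
			return push(instance.Addr, indexes)
		},
		ring.DoBatchOptions{}, // zero value: defaults are filled in internally
	)
}
```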
-func DoBatchWithOptions(ctx context.Context, op Operation, r ReadRing, keys []uint32, callback func(InstanceDesc, []int) error, o DoBatchOptions) error { +func DoBatchWithOptions(ctx context.Context, op Operation, r DoBatchRing, keys []uint32, callback func(InstanceDesc, []int) error, o DoBatchOptions) error { o.replaceZeroValuesWithDefaults() if r.InstancesCount() <= 0 { o.Cleanup() return fmt.Errorf("DoBatch: InstancesCount <= 0") } - expectedTrackers := len(keys) * (r.ReplicationFactor() + 1) / r.InstancesCount() + expectedTrackersPerInstance := len(keys) * (r.ReplicationFactor() + 1) / r.InstancesCount() itemTrackers := make([]itemTracker, len(keys)) instances := make(map[string]instance, r.InstancesCount()) @@ -132,8 +149,8 @@ func DoBatchWithOptions(ctx context.Context, op Operation, r ReadRing, keys []ui for _, desc := range replicationSet.Instances { curr, found := instances[desc.Addr] if !found { - curr.itemTrackers = make([]*itemTracker, 0, expectedTrackers) - curr.indexes = make([]int, 0, expectedTrackers) + curr.itemTrackers = make([]*itemTracker, 0, expectedTrackersPerInstance) + curr.indexes = make([]int, 0, expectedTrackersPerInstance) } instances[desc.Addr] = instance{ desc: desc, diff --git a/vendor/github.com/grafana/dskit/ring/model.go b/vendor/github.com/grafana/dskit/ring/model.go index 956dbe0cf422..334b027d0f8b 100644 --- a/vendor/github.com/grafana/dskit/ring/model.go +++ b/vendor/github.com/grafana/dskit/ring/model.go @@ -21,6 +21,13 @@ func (ts ByAddr) Len() int { return len(ts) } func (ts ByAddr) Swap(i, j int) { ts[i], ts[j] = ts[j], ts[i] } func (ts ByAddr) Less(i, j int) bool { return ts[i].Addr < ts[j].Addr } +// ByID is a sortable list of InstanceDesc. +type ByID []InstanceDesc + +func (ts ByID) Len() int { return len(ts) } +func (ts ByID) Swap(i, j int) { ts[i], ts[j] = ts[j], ts[i] } +func (ts ByID) Less(i, j int) bool { return ts[i].Id < ts[j].Id } + // ProtoDescFactory makes new Descs func ProtoDescFactory() proto.Message { return NewDesc() @@ -195,7 +202,6 @@ func (d *Desc) mergeWithTime(mergeable memberlist.Mergeable, localCAS bool, now other, ok := mergeable.(*Desc) if !ok { - // This method only deals with non-nil rings. 
return nil, fmt.Errorf("expected *ring.Desc, got %T", mergeable) } @@ -512,6 +518,40 @@ func (d *Desc) getOldestRegisteredTimestamp() int64 { return result } +func (d *Desc) instancesWithTokensCount() int { + count := 0 + if d != nil { + for _, ingester := range d.Ingesters { + if len(ingester.Tokens) > 0 { + count++ + } + } + } + return count +} + +func (d *Desc) instancesCountPerZone() map[string]int { + instancesCountPerZone := map[string]int{} + if d != nil { + for _, ingester := range d.Ingesters { + instancesCountPerZone[ingester.Zone]++ + } + } + return instancesCountPerZone +} + +func (d *Desc) instancesWithTokensCountPerZone() map[string]int { + instancesCountPerZone := map[string]int{} + if d != nil { + for _, ingester := range d.Ingesters { + if len(ingester.Tokens) > 0 { + instancesCountPerZone[ingester.Zone]++ + } + } + } + return instancesCountPerZone +} + type CompareResult int // CompareResult responses diff --git a/vendor/github.com/grafana/dskit/ring/partition_instance_lifecycler.go b/vendor/github.com/grafana/dskit/ring/partition_instance_lifecycler.go new file mode 100644 index 000000000000..09fef7223370 --- /dev/null +++ b/vendor/github.com/grafana/dskit/ring/partition_instance_lifecycler.go @@ -0,0 +1,412 @@ +package ring + +import ( + "context" + "time" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "go.uber.org/atomic" + + "github.com/grafana/dskit/kv" + "github.com/grafana/dskit/services" +) + +var ( + ErrPartitionDoesNotExist = errors.New("the partition does not exist") + ErrPartitionStateMismatch = errors.New("the partition state does not match the expected one") + ErrPartitionStateChangeNotAllowed = errors.New("partition state change not allowed") + + allowedPartitionStateChanges = map[PartitionState][]PartitionState{ + PartitionPending: {PartitionActive, PartitionInactive}, + PartitionActive: {PartitionInactive}, + PartitionInactive: {PartitionActive}, + } +) + +type PartitionInstanceLifecyclerConfig struct { + // PartitionID is the ID of the partition managed by the lifecycler. + PartitionID int32 + + // InstanceID is the ID of the instance managed by the lifecycler. + InstanceID string + + // WaitOwnersCountOnPending is the minimum number of owners to wait before switching a + // PENDING partition to ACTIVE. + WaitOwnersCountOnPending int + + // WaitOwnersDurationOnPending is how long each owner should have been added to the + // partition before it's considered eligible for the WaitOwnersCountOnPending count. + WaitOwnersDurationOnPending time.Duration + + // DeleteInactivePartitionAfterDuration is how long the lifecycler should wait before + // deleting inactive partitions with no owners. Inactive partitions are never removed + // if this value is 0. + DeleteInactivePartitionAfterDuration time.Duration + + // PollingInterval is the internal polling interval. This setting is useful to let + // upstream projects lower it in unit tests. + PollingInterval time.Duration +} + +// PartitionInstanceLifecycler is responsible for managing the lifecycle of a single +// partition and partition owner in the ring. +type PartitionInstanceLifecycler struct { + *services.BasicService + + // These values are initialised at startup, and never change.
+ cfg PartitionInstanceLifecyclerConfig + ringName string + ringKey string + store kv.Client + logger log.Logger + + // Channel used to execute logic within the lifecycler loop. + actorChan chan func() + + // Whether the partition should be created on startup if it doesn't exist yet. + createPartitionOnStartup *atomic.Bool + + // Whether the lifecycler should remove the partition owner (identified by instance ID) on shutdown. + removeOwnerOnShutdown *atomic.Bool + + // Metrics. + reconcilesTotal *prometheus.CounterVec + reconcilesFailedTotal *prometheus.CounterVec +} + +func NewPartitionInstanceLifecycler(cfg PartitionInstanceLifecyclerConfig, ringName, ringKey string, store kv.Client, logger log.Logger, reg prometheus.Registerer) *PartitionInstanceLifecycler { + if cfg.PollingInterval == 0 { + cfg.PollingInterval = 5 * time.Second + } + + l := &PartitionInstanceLifecycler{ + cfg: cfg, + ringName: ringName, + ringKey: ringKey, + store: store, + logger: log.With(logger, "ring", ringName), + actorChan: make(chan func()), + createPartitionOnStartup: atomic.NewBool(true), + removeOwnerOnShutdown: atomic.NewBool(false), + reconcilesTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Name: "partition_ring_lifecycler_reconciles_total", + Help: "Total number of reconciliations started.", + ConstLabels: map[string]string{"name": ringName}, + }, []string{"type"}), + reconcilesFailedTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Name: "partition_ring_lifecycler_reconciles_failed_total", + Help: "Total number of reconciliations failed.", + ConstLabels: map[string]string{"name": ringName}, + }, []string{"type"}), + } + + l.BasicService = services.NewBasicService(l.starting, l.running, l.stopping) + + return l +} + +// CreatePartitionOnStartup returns whether the lifecycler creates the partition on startup +// if it doesn't exist. +func (l *PartitionInstanceLifecycler) CreatePartitionOnStartup() bool { + return l.createPartitionOnStartup.Load() +} + +// SetCreatePartitionOnStartup sets whether the lifecycler should create the partition on +// startup if it doesn't exist. +func (l *PartitionInstanceLifecycler) SetCreatePartitionOnStartup(create bool) { + l.createPartitionOnStartup.Store(create) +} + +// RemoveOwnerOnShutdown returns whether the lifecycler has been configured to remove the partition +// owner on shutdown. +func (l *PartitionInstanceLifecycler) RemoveOwnerOnShutdown() bool { + return l.removeOwnerOnShutdown.Load() +} + +// SetRemoveOwnerOnShutdown sets whether the lifecycler should remove the partition owner on shutdown. +func (l *PartitionInstanceLifecycler) SetRemoveOwnerOnShutdown(remove bool) { + l.removeOwnerOnShutdown.Store(remove) +} + +// GetPartitionState returns the current state of the partition, and the timestamp when the state was +// last changed. +func (l *PartitionInstanceLifecycler) GetPartitionState(ctx context.Context) (PartitionState, time.Time, error) { + ring, err := l.getRing(ctx) + if err != nil { + return PartitionUnknown, time.Time{}, err + } + + partition, exists := ring.Partitions[l.cfg.PartitionID] + if !exists { + return PartitionUnknown, time.Time{}, ErrPartitionDoesNotExist + } + + return partition.GetState(), partition.GetStateTime(), nil +} + +// ChangePartitionState changes the partition state to toState. +// This function returns ErrPartitionDoesNotExist if the partition doesn't exist, +// and ErrPartitionStateChangeNotAllowed if the state change is not allowed.
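A wiring sketch for the lifecycler (ring/key names, partition ID, and wait thresholds here are illustrative only):

```go
package example

import (
	"context"
	"time"

	"github.com/go-kit/log"
	"github.com/grafana/dskit/kv"
	"github.com/grafana/dskit/ring"
	"github.com/grafana/dskit/services"
	"github.com/prometheus/client_golang/prometheus"
)

// startPartitionLifecycler creates partition 5 (in PENDING state, if missing)
// and registers this instance as one of its owners.
func startPartitionLifecycler(ctx context.Context, store kv.Client, logger log.Logger, reg prometheus.Registerer) (*ring.PartitionInstanceLifecycler, error) {
	cfg := ring.PartitionInstanceLifecyclerConfig{
		PartitionID:                 5,
		InstanceID:                  "ingester-zone-a-5",
		WaitOwnersCountOnPending:    1,
		WaitOwnersDurationOnPending: 10 * time.Second,
	}
	lc := ring.NewPartitionInstanceLifecycler(cfg, "ingester-partitions", "ingester-partitions-key", store, logger, reg)
	return lc, services.StartAndAwaitRunning(ctx, lc)
}
```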
+func (l *PartitionInstanceLifecycler) ChangePartitionState(ctx context.Context, toState PartitionState) error { + return l.runOnLifecyclerLoop(func() error { + err := l.updateRing(ctx, func(ring *PartitionRingDesc) (bool, error) { + return changePartitionState(ring, l.cfg.PartitionID, toState) + }) + + if err != nil { + level.Warn(l.logger).Log("msg", "failed to change partition state", "partition", l.cfg.PartitionID, "to_state", toState, "err", err) + } + + return err + }) +} + +func (l *PartitionInstanceLifecycler) starting(ctx context.Context) error { + if l.CreatePartitionOnStartup() { + return errors.Wrap(l.createPartitionAndRegisterOwner(ctx), "create partition and register owner") + } + + return errors.Wrap(l.waitPartitionAndRegisterOwner(ctx), "wait partition and register owner") +} + +func (l *PartitionInstanceLifecycler) running(ctx context.Context) error { + reconcile := func() { + l.reconcileOwnedPartition(ctx, time.Now()) + l.reconcileOtherPartitions(ctx, time.Now()) + } + + // Run a reconciliation as soon as the lifecycler starts, in order to avoid waiting for the 1st timer tick. + reconcile() + + reconcileTicker := time.NewTicker(l.cfg.PollingInterval) + defer reconcileTicker.Stop() + + for { + select { + case <-reconcileTicker.C: + reconcile() + + case f := <-l.actorChan: + f() + + case <-ctx.Done(): + return nil + } + } +} + +func (l *PartitionInstanceLifecycler) stopping(_ error) error { + level.Info(l.logger).Log("msg", "partition ring lifecycler is shutting down", "ring", l.ringName) + + // Remove the instance from partition owners, if configured to do so. + if l.RemoveOwnerOnShutdown() { + err := l.updateRing(context.Background(), func(ring *PartitionRingDesc) (bool, error) { + return ring.RemoveOwner(l.cfg.InstanceID), nil + }) + + if err != nil { + level.Error(l.logger).Log("msg", "failed to remove instance from partition owners on shutdown", "instance", l.cfg.InstanceID, "partition", l.cfg.PartitionID, "err", err) + } else { + level.Info(l.logger).Log("msg", "instance removed from partition owners", "instance", l.cfg.InstanceID, "partition", l.cfg.PartitionID) + } + } + + return nil +} + +// runOnLifecyclerLoop runs fn within the lifecycler loop.
+func (l *PartitionInstanceLifecycler) runOnLifecyclerLoop(fn func() error) error { + sc := l.ServiceContext() + if sc == nil { + return errors.New("lifecycler not running") + } + + errCh := make(chan error) + wrappedFn := func() { + errCh <- fn() + } + + select { + case <-sc.Done(): + return errors.New("lifecycler not running") + case l.actorChan <- wrappedFn: + return <-errCh + } +} + +func (l *PartitionInstanceLifecycler) getRing(ctx context.Context) (*PartitionRingDesc, error) { + in, err := l.store.Get(ctx, l.ringKey) + if err != nil { + return nil, err + } + + return GetOrCreatePartitionRingDesc(in), nil +} + +func (l *PartitionInstanceLifecycler) updateRing(ctx context.Context, update func(ring *PartitionRingDesc) (bool, error)) error { + return l.store.CAS(ctx, l.ringKey, func(in interface{}) (out interface{}, retry bool, err error) { + ringDesc := GetOrCreatePartitionRingDesc(in) + + if changed, err := update(ringDesc); err != nil { + return nil, false, err + } else if !changed { + return nil, false, nil + } + + return ringDesc, true, nil + }) +} + +func (l *PartitionInstanceLifecycler) createPartitionAndRegisterOwner(ctx context.Context) error { + return l.updateRing(ctx, func(ring *PartitionRingDesc) (bool, error) { + now := time.Now() + changed := false + + partitionDesc, exists := ring.Partitions[l.cfg.PartitionID] + if exists { + level.Info(l.logger).Log("msg", "partition found in the ring", "partition", l.cfg.PartitionID, "state", partitionDesc.GetState(), "state_timestamp", partitionDesc.GetStateTime().String(), "tokens", len(partitionDesc.GetTokens())) + } else { + level.Info(l.logger).Log("msg", "partition not found in the ring", "partition", l.cfg.PartitionID) + } + + if !exists { + // The partition doesn't exist, so we create a new one. A new partition should always be created + // in PENDING state. + ring.AddPartition(l.cfg.PartitionID, PartitionPending, now) + changed = true + } + + // Ensure the instance is added as partition owner. + if ring.AddOrUpdateOwner(l.cfg.InstanceID, OwnerActive, l.cfg.PartitionID, now) { + changed = true + } + + return changed, nil + }) +} + +func (l *PartitionInstanceLifecycler) waitPartitionAndRegisterOwner(ctx context.Context) error { + pollTicker := time.NewTicker(l.cfg.PollingInterval) + defer pollTicker.Stop() + + // Wait until the partition exists. + checkPartitionExist := func() (bool, error) { + level.Info(l.logger).Log("msg", "checking if the partition exists in the ring", "partition", l.cfg.PartitionID) + + ring, err := l.getRing(ctx) + if err != nil { + return false, errors.Wrap(err, "read partition ring") + } + + if ring.HasPartition(l.cfg.PartitionID) { + level.Info(l.logger).Log("msg", "partition found in the ring", "partition", l.cfg.PartitionID) + return true, nil + } + + level.Info(l.logger).Log("msg", "partition not found in the ring", "partition", l.cfg.PartitionID) + return false, nil + } + + for { + if exists, err := checkPartitionExist(); err != nil { + return err + } else if exists { + break + } + + select { + case <-ctx.Done(): + return ctx.Err() + + case <-pollTicker.C: + // Throttle. + } + } + + // Ensure the instance is added as partition owner. + return l.updateRing(ctx, func(ring *PartitionRingDesc) (bool, error) { + return ring.AddOrUpdateOwner(l.cfg.InstanceID, OwnerActive, l.cfg.PartitionID, time.Now()), nil + }) +} + +// reconcileOwnedPartition reconciles the owned partition. +// This function should be called periodically.
+func (l *PartitionInstanceLifecycler) reconcileOwnedPartition(ctx context.Context, now time.Time) { + const reconcileType = "owned-partition" + l.reconcilesTotal.WithLabelValues(reconcileType).Inc() + + err := l.updateRing(ctx, func(ring *PartitionRingDesc) (bool, error) { + partitionID := l.cfg.PartitionID + + partition, exists := ring.Partitions[partitionID] + if !exists { + return false, ErrPartitionDoesNotExist + } + + // A pending partition should be switched to active once enough owners have been + // registered for longer than the waiting period. + if partition.IsPending() && ring.PartitionOwnersCountUpdatedBefore(partitionID, now.Add(-l.cfg.WaitOwnersDurationOnPending)) >= l.cfg.WaitOwnersCountOnPending { + level.Info(l.logger).Log("msg", "switching partition state because enough owners have been registered and minimum waiting time has elapsed", "partition", l.cfg.PartitionID, "from_state", PartitionPending, "to_state", PartitionActive) + return ring.UpdatePartitionState(partitionID, PartitionActive, now), nil + } + + return false, nil + }) + + if err != nil { + l.reconcilesFailedTotal.WithLabelValues(reconcileType).Inc() + level.Warn(l.logger).Log("msg", "failed to reconcile owned partition", "partition", l.cfg.PartitionID, "err", err) + } +} + +// reconcileOtherPartitions reconciles other partitions. +// This function should be called periodically. +func (l *PartitionInstanceLifecycler) reconcileOtherPartitions(ctx context.Context, now time.Time) { + const reconcileType = "other-partitions" + l.reconcilesTotal.WithLabelValues(reconcileType).Inc() + + err := l.updateRing(ctx, func(ring *PartitionRingDesc) (bool, error) { + changed := false + + if l.cfg.DeleteInactivePartitionAfterDuration > 0 { + deleteBefore := now.Add(-l.cfg.DeleteInactivePartitionAfterDuration) + + for partitionID, partition := range ring.Partitions { + // Never delete the partition owned by this lifecycler, since it's expected to have at least + // this instance as owner. + if partitionID == l.cfg.PartitionID { + continue + } + + // A partition is safe to be removed only if it's been inactive for longer than the wait period + // and it has no owners registered. + if partition.IsInactiveSince(deleteBefore) && ring.PartitionOwnersCount(partitionID) == 0 { + level.Info(l.logger).Log("msg", "removing inactive partition with no owners from ring", "partition", partitionID, "state", partition.State.CleanName(), "state_timestamp", partition.GetStateTime().String()) + ring.RemovePartition(partitionID) + changed = true + } + } + } + + return changed, nil + }) + + if err != nil { + l.reconcilesFailedTotal.WithLabelValues(reconcileType).Inc() + level.Warn(l.logger).Log("msg", "failed to reconcile other partitions", "err", err) + } +} + +func isPartitionStateChangeAllowed(from, to PartitionState) bool { + for _, allowed := range allowedPartitionStateChanges[from] { + if to == allowed { + return true + } + } + + return false +} diff --git a/vendor/github.com/grafana/dskit/ring/partition_instance_ring.go b/vendor/github.com/grafana/dskit/ring/partition_instance_ring.go new file mode 100644 index 000000000000..2fb15d8af98d --- /dev/null +++ b/vendor/github.com/grafana/dskit/ring/partition_instance_ring.go @@ -0,0 +1,150 @@ +package ring + +import ( + "fmt" + "time" + + "golang.org/x/exp/slices" +) + +type PartitionRingReader interface { + // PartitionRing returns a snapshot of the PartitionRing. This function must never return nil.
// If the ring is empty or unknown, an empty PartitionRing can be returned. PartitionRing() *PartitionRing } +// PartitionInstanceRing holds a partitions ring and an instances ring, and provides functions +// to look up the intersection of the two (e.g. healthy instances by partition). +type PartitionInstanceRing struct { + partitionsRingReader PartitionRingReader + instancesRing *Ring + heartbeatTimeout time.Duration +} + +func NewPartitionInstanceRing(partitionsRingWatcher PartitionRingReader, instancesRing *Ring, heartbeatTimeout time.Duration) *PartitionInstanceRing { + return &PartitionInstanceRing{ + partitionsRingReader: partitionsRingWatcher, + instancesRing: instancesRing, + heartbeatTimeout: heartbeatTimeout, + } +} + +func (r *PartitionInstanceRing) PartitionRing() *PartitionRing { + return r.partitionsRingReader.PartitionRing() +} + +func (r *PartitionInstanceRing) InstanceRing() *Ring { + return r.instancesRing +} + +// GetReplicationSetsForOperation returns one ReplicationSet for each partition in the ring. +// If there are no healthy owners for a partition, an error is returned. +func (r *PartitionInstanceRing) GetReplicationSetsForOperation(op Operation) ([]ReplicationSet, error) { + partitionsRing := r.PartitionRing() + partitionsRingDesc := partitionsRing.desc + + if len(partitionsRingDesc.Partitions) == 0 { + return nil, ErrEmptyRing + } + + now := time.Now() + result := make([]ReplicationSet, 0, len(partitionsRingDesc.Partitions)) + zonesBuffer := make([]string, 0, 3) // Pre-allocate buffer assuming 3 zones. + + for partitionID := range partitionsRingDesc.Partitions { + ownerIDs := partitionsRing.PartitionOwnerIDs(partitionID) + instances := make([]InstanceDesc, 0, len(ownerIDs)) + + for _, instanceID := range ownerIDs { + instance, err := r.instancesRing.GetInstance(instanceID) + if err != nil { + // If an instance doesn't exist in the instances ring we don't return an error + // but look up the other owners of the partition. + continue + } + + if !instance.IsHealthy(op, r.heartbeatTimeout, now) { + continue + } + + instances = append(instances, instance) + } + + if len(instances) == 0 { + return nil, fmt.Errorf("partition %d: %w", partitionID, ErrTooManyUnhealthyInstances) + } + + // Count the number of unique zones among instances. + zonesBuffer = uniqueZonesFromInstances(instances, zonesBuffer[:0]) + uniqueZones := len(zonesBuffer) + + result = append(result, ReplicationSet{ + Instances: instances, + + // Partitions have no concept of zone, but we enable it in order to support ring's requests + // minimization feature. + ZoneAwarenessEnabled: true, + + // We need a response from at least 1 owner. The assumption is that we have 1 owner per zone + // but it's not guaranteed (depends on how the application was deployed). The safest thing + // we can do here is to just request a successful response from at least 1 zone. + MaxUnavailableZones: uniqueZones - 1, + }) + } + return result, nil +} + +// ShuffleShard wraps PartitionRing.ShuffleShard(). +// +// The PartitionRing embedded in the returned PartitionInstanceRing is based on a snapshot of the partitions ring +// at the time this function gets called. This means that subsequent changes to the partitions ring will not +// be reflected in the returned PartitionInstanceRing.
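A sketch of intersecting the two rings (the heartbeat timeout is an assumed value):

```go
package example

import (
	"time"

	"github.com/grafana/dskit/ring"
)

// replicationSetsForRead builds one ReplicationSet per partition from healthy owners only.
func replicationSetsForRead(partitions ring.PartitionRingReader, instances *ring.Ring) ([]ring.ReplicationSet, error) {
	pir := ring.NewPartitionInstanceRing(partitions, instances, 1*time.Minute)
	return pir.GetReplicationSetsForOperation(ring.Read)
}
```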
+func (r *PartitionInstanceRing) ShuffleShard(identifier string, size int) (*PartitionInstanceRing, error) { + partitionsSubring, err := r.PartitionRing().ShuffleShard(identifier, size) + if err != nil { + return nil, err + } + + return NewPartitionInstanceRing(newStaticPartitionRingReader(partitionsSubring), r.instancesRing, r.heartbeatTimeout), nil +} + +// ShuffleShardWithLookback wraps PartitionRing.ShuffleShardWithLookback(). +// +// The PartitionRing embedded in the returned PartitionInstanceRing is based on a snapshot of the partitions ring +// at the time this function gets called. This means that subsequent changes to the partitions ring will not +// be reflected in the returned PartitionInstanceRing. +func (r *PartitionInstanceRing) ShuffleShardWithLookback(identifier string, size int, lookbackPeriod time.Duration, now time.Time) (*PartitionInstanceRing, error) { + partitionsSubring, err := r.PartitionRing().ShuffleShardWithLookback(identifier, size, lookbackPeriod, now) + if err != nil { + return nil, err + } + + return NewPartitionInstanceRing(newStaticPartitionRingReader(partitionsSubring), r.instancesRing, r.heartbeatTimeout), nil +} + +type staticPartitionRingReader struct { + ring *PartitionRing +} + +func newStaticPartitionRingReader(ring *PartitionRing) staticPartitionRingReader { + return staticPartitionRingReader{ + ring: ring, + } +} + +func (m staticPartitionRingReader) PartitionRing() *PartitionRing { + return m.ring +} + +// uniqueZonesFromInstances returns the unique list of zones among the input instances. The input buf MUST have +// zero length, but may have capacity in order to avoid memory allocations. +func uniqueZonesFromInstances(instances []InstanceDesc, buf []string) []string { + for _, instance := range instances { + if !slices.Contains(buf, instance.Zone) { + buf = append(buf, instance.Zone) + } + } + + return buf +} diff --git a/vendor/github.com/grafana/dskit/ring/partition_ring.go b/vendor/github.com/grafana/dskit/ring/partition_ring.go new file mode 100644 index 000000000000..911de476c865 --- /dev/null +++ b/vendor/github.com/grafana/dskit/ring/partition_ring.go @@ -0,0 +1,487 @@ +package ring + +import ( + "bytes" + "fmt" + "math" + "math/rand" + "strconv" + "time" + + "golang.org/x/exp/slices" + + shardUtil "github.com/grafana/dskit/ring/shard" +) + +var ErrNoActivePartitionFound = fmt.Errorf("no active partition found") + +// PartitionRing holds an immutable view of the partitions ring. +// +// Design principles: +// - Immutable: the PartitionRingDesc held by PartitionRing is immutable. When PartitionRingDesc changes +// a new instance of PartitionRing should be created. The partitions ring is expected to change infrequently +// (e.g. there's no heartbeat), so creating a new PartitionRing each time the partitions ring changes is +// not expected to have a significant overhead. +type PartitionRing struct { + // desc is a snapshot of the partition ring. This data is immutable and MUST NOT be modified. + desc PartitionRingDesc + + // ringTokens is a sorted list of all tokens registered by all partitions. + ringTokens Tokens + + // partitionByToken is a map where the key is a registered token and the value is the ID of the partition + // that registered that token. + partitionByToken map[Token]int32 + + // ownersByPartition is a map where the key is the partition ID and the value is a list of owner IDs.
ownersByPartition map[int32][]string + + // shuffleShardCache is used to cache subrings generated with shuffle sharding. + shuffleShardCache *partitionRingShuffleShardCache + + // activePartitionsCount is a saved count of active partitions to avoid recomputing it. + activePartitionsCount int +} + +func NewPartitionRing(desc PartitionRingDesc) *PartitionRing { + return &PartitionRing{ + desc: desc, + ringTokens: desc.tokens(), + partitionByToken: desc.partitionByToken(), + ownersByPartition: desc.ownersByPartition(), + activePartitionsCount: desc.activePartitionsCount(), + shuffleShardCache: newPartitionRingShuffleShardCache(), + } +} + +// ActivePartitionForKey returns the partition for the given key. Only active partitions are considered. +// Only one partition is returned: in other terms, the replication factor is always 1. +func (r *PartitionRing) ActivePartitionForKey(key uint32) (int32, error) { + var ( + start = searchToken(r.ringTokens, key) + iterations = 0 + tokensCount = len(r.ringTokens) + ) + + for i := start; iterations < tokensCount; i++ { + iterations++ + + if i >= tokensCount { + i %= len(r.ringTokens) + } + + token := r.ringTokens[i] + + partitionID, ok := r.partitionByToken[Token(token)] + if !ok { + return 0, ErrInconsistentTokensInfo + } + + partition, ok := r.desc.Partitions[partitionID] + if !ok { + return 0, ErrInconsistentTokensInfo + } + + // If the partition is not active we'll keep walking the ring. + if partition.IsActive() { + return partitionID, nil + } + } + + return 0, ErrNoActivePartitionFound +} + +// ShuffleShardSize returns the number of partitions that would be in the result of a ShuffleShard call with the same size. +func (r *PartitionRing) ShuffleShardSize(size int) int { + if size <= 0 || size > r.activePartitionsCount { + return r.activePartitionsCount + } + + if size < r.activePartitionsCount { + return size + } + return r.activePartitionsCount +} + +// ShuffleShard returns a subring for the provided identifier (e.g. a tenant ID) +// and size (number of partitions). +// +// The algorithm used to build the subring is a shuffle sharder based on probabilistic +// hashing. We pick N unique partitions, walking the ring starting from random but +// predictable numbers. The random generator is initialised with a seed based on the +// provided identifier. +// +// This function returns a subring containing ONLY ACTIVE partitions. +// +// This function supports caching. +// +// This implementation guarantees: +// +// - Stability: given the same ring, two invocations return the same result. +// +// - Consistency: adding/removing 1 partition from the ring generates a resulting +// subring with no more than 1 difference. +// +// - Shuffling: probabilistically, for a large enough cluster each identifier gets a different +// set of partitions, with a reduced number of overlapping partitions between two identifiers. +func (r *PartitionRing) ShuffleShard(identifier string, size int) (*PartitionRing, error) { + if cached := r.shuffleShardCache.getSubring(identifier, size); cached != nil { + return cached, nil + } + + // No need to pass the time if there's no lookback. + subring, err := r.shuffleShard(identifier, size, 0, time.Time{}) + if err != nil { + return nil, err + } + + r.shuffleShardCache.setSubring(identifier, size, subring) + return subring, nil +} + +// ShuffleShardWithLookback is like ShuffleShard() but the returned subring includes all partitions +// that have been part of the identifier's shard in [now - lookbackPeriod, now] time window.
+//
+// This function can return a mix of ACTIVE and INACTIVE partitions. INACTIVE partitions are only
+// included if they were part of the identifier's shard within the lookbackPeriod. PENDING partitions
+// are never returned.
+//
+// This function supports caching, but the cache will only be effective if successive calls for the
+// same identifier use the same lookbackPeriod and increasing values of now.
+func (r *PartitionRing) ShuffleShardWithLookback(identifier string, size int, lookbackPeriod time.Duration, now time.Time) (*PartitionRing, error) {
+    if cached := r.shuffleShardCache.getSubringWithLookback(identifier, size, lookbackPeriod, now); cached != nil {
+        return cached, nil
+    }
+
+    subring, err := r.shuffleShard(identifier, size, lookbackPeriod, now)
+    if err != nil {
+        return nil, err
+    }
+
+    r.shuffleShardCache.setSubringWithLookback(identifier, size, lookbackPeriod, now, subring)
+    return subring, nil
+}
+
+func (r *PartitionRing) shuffleShard(identifier string, size int, lookbackPeriod time.Duration, now time.Time) (*PartitionRing, error) {
+    // If the size is too small or too large, run with a size equal to the total number of partitions.
+    // We have to run the function anyway because the logic may filter out some INACTIVE partitions.
+    if size <= 0 || size >= len(r.desc.Partitions) {
+        size = len(r.desc.Partitions)
+    }
+
+    var lookbackUntil int64
+    if lookbackPeriod > 0 {
+        lookbackUntil = now.Add(-lookbackPeriod).Unix()
+    }
+
+    // Initialise the random generator used to select instances in the ring.
+    // There are no zones in the partitions ring, so the zone is left empty.
+    random := rand.New(rand.NewSource(shardUtil.ShuffleShardSeed(identifier, "")))
+
+    // To select one more instance while guaranteeing the "consistency" property,
+    // we pick a random value from the generator and resolve uniqueness collisions
+    // (if any) by continuing to walk the ring.
+    tokensCount := len(r.ringTokens)
+
+    result := make(map[int32]struct{}, size)
+    exclude := map[int32]struct{}{}
+
+    for len(result) < size {
+        start := searchToken(r.ringTokens, random.Uint32())
+        iterations := 0
+        found := false
+
+        for p := start; !found && iterations < tokensCount; p++ {
+            iterations++
+
+            // Wrap p around in the ring.
+            if p >= tokensCount {
+                p %= tokensCount
+            }
+
+            pid, ok := r.partitionByToken[Token(r.ringTokens[p])]
+            if !ok {
+                return nil, ErrInconsistentTokensInfo
+            }
+
+            // Ensure the partition has not already been included or excluded.
+            if _, ok := result[pid]; ok {
+                continue
+            }
+            if _, ok := exclude[pid]; ok {
+                continue
+            }
+
+            p, ok := r.desc.Partitions[pid]
+            if !ok {
+                return nil, ErrInconsistentTokensInfo
+            }
+
+            // PENDING partitions should be skipped because they're not ready for read or write yet,
+            // and they're not subject to the lookback period.
+            if p.IsPending() {
+                exclude[pid] = struct{}{}
+                continue
+            }
+
+            var (
+                withinLookbackPeriod = lookbackPeriod > 0 && p.GetStateTimestamp() >= lookbackUntil
+                shouldExtend         = withinLookbackPeriod
+                shouldInclude        = p.IsActive() || withinLookbackPeriod
+            )
+
+            // Either include or exclude the found partition.
+            if shouldInclude {
+                result[pid] = struct{}{}
+            } else {
+                exclude[pid] = struct{}{}
+            }
+
+            // Extend the shard, if requested.
+            if shouldExtend {
+                size++
+            }
+
+            // We can stop searching for other partitions only if this partition was included
+            // and no extension was requested, which means it's the "stop partition" for this cycle.
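+            // For illustration (hypothetical states): an ACTIVE partition whose state changed
+            // before the lookback window is included and ends the walk for this cycle, while an
+            // INACTIVE partition that changed state within the window is included but also
+            // extends the shard by one, so the walk continues looking for a partition that can
+            // be included without extending.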
+            if shouldInclude && !shouldExtend {
+                found = true
+            }
+        }
+
+        // If we iterated over all tokens and no new partition was found, we can stop looking for more partitions.
+        if !found {
+            break
+        }
+    }
+
+    return NewPartitionRing(r.desc.WithPartitions(result)), nil
+}
+
+// PartitionsCount returns the number of partitions in the ring.
+func (r *PartitionRing) PartitionsCount() int {
+    return len(r.desc.Partitions)
+}
+
+// ActivePartitionsCount returns the number of active partitions in the ring.
+func (r *PartitionRing) ActivePartitionsCount() int {
+    return r.activePartitionsCount
+}
+
+// Partitions returns the partitions in the ring.
+// The returned slice is a deep copy, so the caller can freely manipulate it.
+func (r *PartitionRing) Partitions() []PartitionDesc {
+    res := make([]PartitionDesc, 0, len(r.desc.Partitions))
+
+    for _, partition := range r.desc.Partitions {
+        res = append(res, partition.Clone())
+    }
+
+    return res
+}
+
+// PartitionIDs returns a sorted list of all partition IDs in the ring.
+// The returned slice is a copy, so the caller can freely manipulate it.
+func (r *PartitionRing) PartitionIDs() []int32 {
+    ids := make([]int32, 0, len(r.desc.Partitions))
+
+    for id := range r.desc.Partitions {
+        ids = append(ids, id)
+    }
+
+    slices.Sort(ids)
+    return ids
+}
+
+// PendingPartitionIDs returns a sorted list of all PENDING partition IDs in the ring.
+// The returned slice is a copy, so the caller can freely manipulate it.
+func (r *PartitionRing) PendingPartitionIDs() []int32 {
+    ids := make([]int32, 0, len(r.desc.Partitions))
+
+    for id, partition := range r.desc.Partitions {
+        if partition.IsPending() {
+            ids = append(ids, id)
+        }
+    }
+
+    slices.Sort(ids)
+    return ids
+}
+
+// ActivePartitionIDs returns a sorted list of all ACTIVE partition IDs in the ring.
+// The returned slice is a copy, so the caller can freely manipulate it.
+func (r *PartitionRing) ActivePartitionIDs() []int32 {
+    ids := make([]int32, 0, len(r.desc.Partitions))
+
+    for id, partition := range r.desc.Partitions {
+        if partition.IsActive() {
+            ids = append(ids, id)
+        }
+    }
+
+    slices.Sort(ids)
+    return ids
+}
+
+// InactivePartitionIDs returns a sorted list of all INACTIVE partition IDs in the ring.
+// The returned slice is a copy, so the caller can freely manipulate it.
+func (r *PartitionRing) InactivePartitionIDs() []int32 {
+    ids := make([]int32, 0, len(r.desc.Partitions))
+
+    for id, partition := range r.desc.Partitions {
+        if partition.IsInactive() {
+            ids = append(ids, id)
+        }
+    }
+
+    slices.Sort(ids)
+    return ids
+}
+
+// PartitionOwnerIDs returns a list of owner IDs for the given partitionID.
+// The returned slice is NOT a copy and must never be modified by the caller.
+func (r *PartitionRing) PartitionOwnerIDs(partitionID int32) (doNotModify []string) {
+    return r.ownersByPartition[partitionID]
+}
+
+// PartitionOwnerIDsCopy is like PartitionOwnerIDs(), but the returned slice is a copy,
+// so the caller can freely manipulate it.
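+//
+// A minimal usage sketch (assuming a ring variable of type *PartitionRing and a
+// hypothetical partition ID):
+//
+//	owners := ring.PartitionOwnerIDsCopy(1)
+//	owners = append(owners, "extra-owner") // safe: the copy can be modified freely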
+func (r *PartitionRing) PartitionOwnerIDsCopy(partitionID int32) []string {
+    ids := r.ownersByPartition[partitionID]
+    if len(ids) == 0 {
+        return nil
+    }
+
+    return slices.Clone(ids)
+}
+
+func (r *PartitionRing) String() string {
+    buf := bytes.Buffer{}
+    for pid, pd := range r.desc.Partitions {
+        buf.WriteString(fmt.Sprintf(" %d:%v", pid, pd.State.String()))
+    }
+
+    return fmt.Sprintf("PartitionRing{ownersCount: %d, partitionsCount: %d, partitions: {%s}}", len(r.desc.Owners), len(r.desc.Partitions), buf.String())
+}
+
+// GetTokenRangesForPartition returns the token ranges owned by the given partition. Note that this
+// method does NOT take partition state into account, so if only active partitions should be
+// considered, then a PartitionRing with only active partitions must be created first (e.g. using the ShuffleShard method).
+func (r *PartitionRing) GetTokenRangesForPartition(partitionID int32) (TokenRanges, error) {
+    partition, ok := r.desc.Partitions[partitionID]
+    if !ok {
+        return nil, ErrPartitionDoesNotExist
+    }
+
+    // 1 range (2 values) per token + one additional if we need to split the rollover range.
+    ranges := make(TokenRanges, 0, 2*(len(partition.Tokens)+1))
+
+    addRange := func(start, end uint32) {
+        // Check if we can group ranges. If so, we just update the end of the previous range.
+        if len(ranges) > 0 && ranges[len(ranges)-1] == start-1 {
+            ranges[len(ranges)-1] = end
+        } else {
+            ranges = append(ranges, start, end)
+        }
+    }
+
+    // The "last" range is the range that includes token math.MaxUint32.
+    ownsLastRange := false
+    startOfLastRange := uint32(0)
+
+    // We start with all tokens, but will remove tokens we already skipped, to let binary search do less work.
+    ringTokens := r.ringTokens
+
+    for iter, t := range partition.Tokens {
+        lastOwnedToken := t - 1
+
+        ix := searchToken(ringTokens, lastOwnedToken)
+        prevIx := ix - 1
+
+        if prevIx < 0 {
+            // We can only find the "last" range during the first iteration.
+            if iter > 0 {
+                return nil, ErrInconsistentTokensInfo
+            }
+
+            prevIx = len(ringTokens) - 1
+            ownsLastRange = true
+
+            startOfLastRange = ringTokens[prevIx]
+
+            // We can only claim token 0 if our actual token in the ring (which is the exclusive end of the range) was not 0.
+            if t > 0 {
+                addRange(0, lastOwnedToken)
+            }
+        } else {
+            addRange(ringTokens[prevIx], lastOwnedToken)
+        }
+
+        // Reduce the number of tokens we need to search through. We keep the current token to serve as the min boundary
+        // for the next search, to make sure we don't find another "last" range (where prevIx < 0).
+        ringTokens = ringTokens[ix:]
+    }
+
+    if ownsLastRange {
+        addRange(startOfLastRange, math.MaxUint32)
+    }
+
+    return ranges, nil
+}
+
+// ActivePartitionBatchRing wraps PartitionRing and implements DoBatchRing to look up ACTIVE partitions.
+type ActivePartitionBatchRing struct {
+    ring *PartitionRing
+}
+
+func NewActivePartitionBatchRing(ring *PartitionRing) *ActivePartitionBatchRing {
+    return &ActivePartitionBatchRing{
+        ring: ring,
+    }
+}
+
+// InstancesCount returns the number of active partitions in the ring.
+//
+// InstancesCount implements DoBatchRing.InstancesCount.
+func (r *ActivePartitionBatchRing) InstancesCount() int {
+    return r.ring.ActivePartitionsCount()
+}
+
+// ReplicationFactor returns 1 as the partitions replication factor: an entry (looked up by key via Get())
+// is always stored in 1 and only 1 partition.
+//
+// ReplicationFactor implements DoBatchRing.ReplicationFactor.
+func (r *ActivePartitionBatchRing) ReplicationFactor() int {
+    return 1
+}
+
+// Get implements DoBatchRing.Get.
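+//
+// The returned ReplicationSet contains a single InstanceDesc whose Addr and Id fields carry
+// the partition ID formatted as a string, since partitions, unlike instances, have no
+// network address.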
+func (r *ActivePartitionBatchRing) Get(key uint32, _ Operation, bufInstances []InstanceDesc, _, _ []string) (ReplicationSet, error) { + partitionID, err := r.ring.ActivePartitionForKey(key) + if err != nil { + return ReplicationSet{}, err + } + + // Ensure we have enough capacity in bufInstances. + if cap(bufInstances) < 1 { + bufInstances = []InstanceDesc{{}} + } else { + bufInstances = bufInstances[:1] + } + + partitionIDString := strconv.Itoa(int(partitionID)) + + bufInstances[0] = InstanceDesc{ + Addr: partitionIDString, + Timestamp: 0, + State: ACTIVE, + Id: partitionIDString, + } + + return ReplicationSet{ + Instances: bufInstances, + MaxErrors: 0, + MaxUnavailableZones: 0, + ZoneAwarenessEnabled: false, + }, nil +} diff --git a/vendor/github.com/grafana/dskit/ring/partition_ring_desc.pb.go b/vendor/github.com/grafana/dskit/ring/partition_ring_desc.pb.go new file mode 100644 index 000000000000..8f47b1c562ea --- /dev/null +++ b/vendor/github.com/grafana/dskit/ring/partition_ring_desc.pb.go @@ -0,0 +1,1545 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: partition_ring_desc.proto + +package ring + +import ( + fmt "fmt" + _ "github.com/gogo/protobuf/gogoproto" + proto "github.com/gogo/protobuf/proto" + github_com_gogo_protobuf_sortkeys "github.com/gogo/protobuf/sortkeys" + io "io" + math "math" + math_bits "math/bits" + reflect "reflect" + strconv "strconv" + strings "strings" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type PartitionState int32 + +const ( + PartitionUnknown PartitionState = 0 + // Pending partition is a partition that is about to be switched to ACTIVE. This state is used + // to let owners to attach to the partition and get ready to handle the partition. + // + // When a partition is in this state, it must not be used for writing or reading. + PartitionPending PartitionState = 1 + // Active partition in read-write mode. + PartitionActive PartitionState = 2 + // Inactive partition in read-only mode. This partition will be deleted after a grace period, + // unless its state changes to Active again. + PartitionInactive PartitionState = 3 + // Deleted partition. This state is not visible to ring clients: it's only used to propagate + // via memberlist the information that a partition has been deleted. + PartitionDeleted PartitionState = 4 +) + +var PartitionState_name = map[int32]string{ + 0: "PartitionUnknown", + 1: "PartitionPending", + 2: "PartitionActive", + 3: "PartitionInactive", + 4: "PartitionDeleted", +} + +var PartitionState_value = map[string]int32{ + "PartitionUnknown": 0, + "PartitionPending": 1, + "PartitionActive": 2, + "PartitionInactive": 3, + "PartitionDeleted": 4, +} + +func (PartitionState) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_4df2762174d93dc4, []int{0} +} + +type OwnerState int32 + +const ( + OwnerUnknown OwnerState = 0 + // Active owner. + OwnerActive OwnerState = 1 + // Deleted owner. This state is not visible to ring clients: it's only used to propagate + // via memberlist the information that a owner has been deleted. 
Owners in this state + // are removed before client can see them. + OwnerDeleted OwnerState = 2 +) + +var OwnerState_name = map[int32]string{ + 0: "OwnerUnknown", + 1: "OwnerActive", + 2: "OwnerDeleted", +} + +var OwnerState_value = map[string]int32{ + "OwnerUnknown": 0, + "OwnerActive": 1, + "OwnerDeleted": 2, +} + +func (OwnerState) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_4df2762174d93dc4, []int{1} +} + +// PartitionRingDesc holds the state of the partitions ring. +type PartitionRingDesc struct { + // Mapping between partition ID and partition info. + Partitions map[int32]PartitionDesc `protobuf:"bytes,1,rep,name=partitions,proto3" json:"partitions" protobuf_key:"varint,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + // Mapping between instance ID and partition ownership info. + Owners map[string]OwnerDesc `protobuf:"bytes,2,rep,name=owners,proto3" json:"owners" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` +} + +func (m *PartitionRingDesc) Reset() { *m = PartitionRingDesc{} } +func (*PartitionRingDesc) ProtoMessage() {} +func (*PartitionRingDesc) Descriptor() ([]byte, []int) { + return fileDescriptor_4df2762174d93dc4, []int{0} +} +func (m *PartitionRingDesc) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *PartitionRingDesc) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_PartitionRingDesc.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *PartitionRingDesc) XXX_Merge(src proto.Message) { + xxx_messageInfo_PartitionRingDesc.Merge(m, src) +} +func (m *PartitionRingDesc) XXX_Size() int { + return m.Size() +} +func (m *PartitionRingDesc) XXX_DiscardUnknown() { + xxx_messageInfo_PartitionRingDesc.DiscardUnknown(m) +} + +var xxx_messageInfo_PartitionRingDesc proto.InternalMessageInfo + +func (m *PartitionRingDesc) GetPartitions() map[int32]PartitionDesc { + if m != nil { + return m.Partitions + } + return nil +} + +func (m *PartitionRingDesc) GetOwners() map[string]OwnerDesc { + if m != nil { + return m.Owners + } + return nil +} + +// PartitionDesc holds the state of a single partition. +type PartitionDesc struct { + // The partition ID. This value is the same as the key in the partitions map in PartitionRingDesc. + Id int32 `protobuf:"varint,4,opt,name=id,proto3" json:"id,omitempty"` + // Unique tokens, generated with deterministic token generator. Tokens MUST be immutable: + // if tokens get changed, the change will not be propagated via memberlist. + Tokens []uint32 `protobuf:"varint,1,rep,packed,name=tokens,proto3" json:"tokens,omitempty"` + // The state of the partition. + State PartitionState `protobuf:"varint,2,opt,name=state,proto3,enum=ring.PartitionState" json:"state,omitempty"` + // Unix timestamp (with seconds precision) of when has the state changed last time for this partition. 
+ StateTimestamp int64 `protobuf:"varint,3,opt,name=stateTimestamp,proto3" json:"stateTimestamp,omitempty"` +} + +func (m *PartitionDesc) Reset() { *m = PartitionDesc{} } +func (*PartitionDesc) ProtoMessage() {} +func (*PartitionDesc) Descriptor() ([]byte, []int) { + return fileDescriptor_4df2762174d93dc4, []int{1} +} +func (m *PartitionDesc) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *PartitionDesc) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_PartitionDesc.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *PartitionDesc) XXX_Merge(src proto.Message) { + xxx_messageInfo_PartitionDesc.Merge(m, src) +} +func (m *PartitionDesc) XXX_Size() int { + return m.Size() +} +func (m *PartitionDesc) XXX_DiscardUnknown() { + xxx_messageInfo_PartitionDesc.DiscardUnknown(m) +} + +var xxx_messageInfo_PartitionDesc proto.InternalMessageInfo + +func (m *PartitionDesc) GetId() int32 { + if m != nil { + return m.Id + } + return 0 +} + +func (m *PartitionDesc) GetTokens() []uint32 { + if m != nil { + return m.Tokens + } + return nil +} + +func (m *PartitionDesc) GetState() PartitionState { + if m != nil { + return m.State + } + return PartitionUnknown +} + +func (m *PartitionDesc) GetStateTimestamp() int64 { + if m != nil { + return m.StateTimestamp + } + return 0 +} + +// OwnerDesc holds the information of a partition owner. +type OwnerDesc struct { + // Partition that belongs to this owner. A owner can own only 1 partition, but 1 partition can be + // owned by multiple owners. + OwnedPartition int32 `protobuf:"varint,1,opt,name=ownedPartition,proto3" json:"ownedPartition,omitempty"` + // The owner state. This field is used to propagate deletions via memberlist. + State OwnerState `protobuf:"varint,2,opt,name=state,proto3,enum=ring.OwnerState" json:"state,omitempty"` + // Unix timestamp (with seconds precision) of when the data for the owner has been updated the last time. + // This timestamp is used to resolve conflicts when merging updates via memberlist (the most recent + // update wins). 
+ UpdatedTimestamp int64 `protobuf:"varint,3,opt,name=updatedTimestamp,proto3" json:"updatedTimestamp,omitempty"` +} + +func (m *OwnerDesc) Reset() { *m = OwnerDesc{} } +func (*OwnerDesc) ProtoMessage() {} +func (*OwnerDesc) Descriptor() ([]byte, []int) { + return fileDescriptor_4df2762174d93dc4, []int{2} +} +func (m *OwnerDesc) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *OwnerDesc) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_OwnerDesc.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *OwnerDesc) XXX_Merge(src proto.Message) { + xxx_messageInfo_OwnerDesc.Merge(m, src) +} +func (m *OwnerDesc) XXX_Size() int { + return m.Size() +} +func (m *OwnerDesc) XXX_DiscardUnknown() { + xxx_messageInfo_OwnerDesc.DiscardUnknown(m) +} + +var xxx_messageInfo_OwnerDesc proto.InternalMessageInfo + +func (m *OwnerDesc) GetOwnedPartition() int32 { + if m != nil { + return m.OwnedPartition + } + return 0 +} + +func (m *OwnerDesc) GetState() OwnerState { + if m != nil { + return m.State + } + return OwnerUnknown +} + +func (m *OwnerDesc) GetUpdatedTimestamp() int64 { + if m != nil { + return m.UpdatedTimestamp + } + return 0 +} + +func init() { + proto.RegisterEnum("ring.PartitionState", PartitionState_name, PartitionState_value) + proto.RegisterEnum("ring.OwnerState", OwnerState_name, OwnerState_value) + proto.RegisterType((*PartitionRingDesc)(nil), "ring.PartitionRingDesc") + proto.RegisterMapType((map[string]OwnerDesc)(nil), "ring.PartitionRingDesc.OwnersEntry") + proto.RegisterMapType((map[int32]PartitionDesc)(nil), "ring.PartitionRingDesc.PartitionsEntry") + proto.RegisterType((*PartitionDesc)(nil), "ring.PartitionDesc") + proto.RegisterType((*OwnerDesc)(nil), "ring.OwnerDesc") +} + +func init() { proto.RegisterFile("partition_ring_desc.proto", fileDescriptor_4df2762174d93dc4) } + +var fileDescriptor_4df2762174d93dc4 = []byte{ + // 497 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x74, 0x93, 0x31, 0x6f, 0xd3, 0x40, + 0x14, 0xc7, 0x7d, 0x76, 0x12, 0xa9, 0x2f, 0x34, 0x39, 0xae, 0x05, 0x99, 0x0c, 0x47, 0x14, 0x44, + 0x09, 0x91, 0x48, 0xa5, 0xc0, 0x80, 0xd8, 0x52, 0x95, 0x01, 0x24, 0x44, 0x65, 0x60, 0xae, 0x9c, + 0xf8, 0x30, 0xa7, 0x34, 0x77, 0x91, 0x7d, 0x6e, 0xd5, 0x05, 0xb1, 0x31, 0xb0, 0xf0, 0x31, 0xf8, + 0x22, 0x48, 0x1d, 0x33, 0x76, 0x42, 0xc4, 0x59, 0x18, 0xfb, 0x11, 0x90, 0xcf, 0xae, 0x63, 0xbb, + 0xea, 0x76, 0xef, 0x7f, 0xef, 0xfd, 0xfe, 0xff, 0x3b, 0x9f, 0xe1, 0xc1, 0xc2, 0x0d, 0x14, 0x57, + 0x5c, 0x8a, 0xe3, 0x80, 0x0b, 0xff, 0xd8, 0x63, 0xe1, 0x74, 0xb8, 0x08, 0xa4, 0x92, 0xa4, 0x96, + 0x08, 0x9d, 0x67, 0x3e, 0x57, 0x5f, 0xa2, 0xc9, 0x70, 0x2a, 0xe7, 0xfb, 0xbe, 0xf4, 0xe5, 0xbe, + 0xde, 0x9c, 0x44, 0x9f, 0x75, 0xa5, 0x0b, 0xbd, 0x4a, 0x87, 0x7a, 0xbf, 0x4d, 0xb8, 0x7b, 0x74, + 0x8d, 0x74, 0xb8, 0xf0, 0x0f, 0x59, 0x38, 0x25, 0xef, 0x00, 0x72, 0x9f, 0xd0, 0x46, 0x5d, 0xab, + 0xdf, 0x1c, 0x3d, 0x19, 0x26, 0xfc, 0xe1, 0x8d, 0xe6, 0x8d, 0x12, 0xbe, 0x16, 0x2a, 0x38, 0x3f, + 0xa8, 0x5d, 0xfc, 0x79, 0x68, 0x38, 0x05, 0x00, 0x19, 0x43, 0x43, 0x9e, 0x09, 0x16, 0x84, 0xb6, + 0xa9, 0x51, 0x8f, 0x6e, 0x43, 0xbd, 0xd7, 0x5d, 0x45, 0x4c, 0x36, 0xd8, 0x71, 0xa0, 0x5d, 0xf1, + 0x21, 0x18, 0xac, 0x19, 0x3b, 0xb7, 0x51, 0x17, 0xf5, 0xeb, 0x4e, 0xb2, 0x24, 0x4f, 0xa1, 0x7e, + 0xea, 0x9e, 0x44, 0xcc, 0x36, 0xbb, 0xa8, 0xdf, 0x1c, 0xed, 0x54, 0x6c, 0x12, 
0x0b, 0x27, 0xed, + 0x78, 0x65, 0xbe, 0x44, 0x9d, 0xb7, 0xd0, 0x2c, 0x18, 0x16, 0x79, 0x5b, 0x29, 0xef, 0x71, 0x99, + 0xd7, 0x4e, 0x79, 0x7a, 0xa6, 0xc2, 0xea, 0xfd, 0x40, 0xb0, 0x5d, 0x32, 0x22, 0x2d, 0x30, 0xb9, + 0x67, 0xd7, 0x74, 0x3a, 0x93, 0x7b, 0xe4, 0x3e, 0x34, 0x94, 0x9c, 0xb1, 0xec, 0x3e, 0xb7, 0x9d, + 0xac, 0x22, 0x03, 0xa8, 0x87, 0xca, 0x55, 0xa9, 0x49, 0x6b, 0xb4, 0x5b, 0x09, 0xfd, 0x21, 0xd9, + 0x73, 0xd2, 0x16, 0xb2, 0x07, 0x2d, 0xbd, 0xf8, 0xc8, 0xe7, 0x2c, 0x54, 0xee, 0x7c, 0x61, 0x5b, + 0x5d, 0xd4, 0xb7, 0x9c, 0x8a, 0xda, 0xfb, 0x8e, 0x60, 0x2b, 0x8f, 0x99, 0x4c, 0x25, 0xb7, 0xe8, + 0xe5, 0xcc, 0xec, 0xce, 0x2a, 0x2a, 0xd9, 0x2b, 0x27, 0xc1, 0x85, 0xe3, 0x96, 0x52, 0x0c, 0x00, + 0x47, 0x0b, 0xcf, 0x55, 0xcc, 0xab, 0xe6, 0xb8, 0xa1, 0x0f, 0xbe, 0x42, 0xab, 0x7c, 0x14, 0xb2, + 0x0b, 0x38, 0x57, 0x3e, 0x89, 0x99, 0x90, 0x67, 0x02, 0x1b, 0x25, 0xf5, 0x88, 0x09, 0x8f, 0x0b, + 0x1f, 0x23, 0xb2, 0x53, 0xf8, 0xea, 0xe3, 0xa9, 0xe2, 0xa7, 0x0c, 0x9b, 0xe4, 0x5e, 0xe1, 0xc5, + 0xbe, 0x11, 0x6e, 0x2a, 0x5b, 0x25, 0xc2, 0x21, 0x3b, 0x61, 0x8a, 0x79, 0xb8, 0x36, 0x18, 0x03, + 0x6c, 0x0e, 0x40, 0x30, 0xdc, 0xd1, 0xd5, 0xc6, 0xb7, 0x9d, 0xbd, 0x81, 0x8c, 0x8e, 0xf2, 0x96, + 0x6b, 0x84, 0x79, 0xf0, 0x62, 0xb9, 0xa2, 0xc6, 0xe5, 0x8a, 0x1a, 0x57, 0x2b, 0x8a, 0xbe, 0xc5, + 0x14, 0xfd, 0x8a, 0x29, 0xba, 0x88, 0x29, 0x5a, 0xc6, 0x14, 0xfd, 0x8d, 0x29, 0xfa, 0x17, 0x53, + 0xe3, 0x2a, 0xa6, 0xe8, 0xe7, 0x9a, 0x1a, 0xcb, 0x35, 0x35, 0x2e, 0xd7, 0xd4, 0x98, 0x34, 0xf4, + 0xff, 0xf5, 0xfc, 0x7f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xc9, 0xd1, 0xa7, 0xbd, 0xb1, 0x03, 0x00, + 0x00, +} + +func (x PartitionState) String() string { + s, ok := PartitionState_name[int32(x)] + if ok { + return s + } + return strconv.Itoa(int(x)) +} +func (x OwnerState) String() string { + s, ok := OwnerState_name[int32(x)] + if ok { + return s + } + return strconv.Itoa(int(x)) +} +func (this *PartitionRingDesc) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*PartitionRingDesc) + if !ok { + that2, ok := that.(PartitionRingDesc) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if len(this.Partitions) != len(that1.Partitions) { + return false + } + for i := range this.Partitions { + a := this.Partitions[i] + b := that1.Partitions[i] + if !(&a).Equal(&b) { + return false + } + } + if len(this.Owners) != len(that1.Owners) { + return false + } + for i := range this.Owners { + a := this.Owners[i] + b := that1.Owners[i] + if !(&a).Equal(&b) { + return false + } + } + return true +} +func (this *PartitionDesc) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*PartitionDesc) + if !ok { + that2, ok := that.(PartitionDesc) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if this.Id != that1.Id { + return false + } + if len(this.Tokens) != len(that1.Tokens) { + return false + } + for i := range this.Tokens { + if this.Tokens[i] != that1.Tokens[i] { + return false + } + } + if this.State != that1.State { + return false + } + if this.StateTimestamp != that1.StateTimestamp { + return false + } + return true +} +func (this *OwnerDesc) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*OwnerDesc) + if !ok { + that2, ok := that.(OwnerDesc) + if ok { + that1 = &that2 + } else { + return false + } + } + if 
that1 == nil { + return this == nil + } else if this == nil { + return false + } + if this.OwnedPartition != that1.OwnedPartition { + return false + } + if this.State != that1.State { + return false + } + if this.UpdatedTimestamp != that1.UpdatedTimestamp { + return false + } + return true +} +func (this *PartitionRingDesc) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&ring.PartitionRingDesc{") + keysForPartitions := make([]int32, 0, len(this.Partitions)) + for k, _ := range this.Partitions { + keysForPartitions = append(keysForPartitions, k) + } + github_com_gogo_protobuf_sortkeys.Int32s(keysForPartitions) + mapStringForPartitions := "map[int32]PartitionDesc{" + for _, k := range keysForPartitions { + mapStringForPartitions += fmt.Sprintf("%#v: %#v,", k, this.Partitions[k]) + } + mapStringForPartitions += "}" + if this.Partitions != nil { + s = append(s, "Partitions: "+mapStringForPartitions+",\n") + } + keysForOwners := make([]string, 0, len(this.Owners)) + for k, _ := range this.Owners { + keysForOwners = append(keysForOwners, k) + } + github_com_gogo_protobuf_sortkeys.Strings(keysForOwners) + mapStringForOwners := "map[string]OwnerDesc{" + for _, k := range keysForOwners { + mapStringForOwners += fmt.Sprintf("%#v: %#v,", k, this.Owners[k]) + } + mapStringForOwners += "}" + if this.Owners != nil { + s = append(s, "Owners: "+mapStringForOwners+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} +func (this *PartitionDesc) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 8) + s = append(s, "&ring.PartitionDesc{") + s = append(s, "Id: "+fmt.Sprintf("%#v", this.Id)+",\n") + s = append(s, "Tokens: "+fmt.Sprintf("%#v", this.Tokens)+",\n") + s = append(s, "State: "+fmt.Sprintf("%#v", this.State)+",\n") + s = append(s, "StateTimestamp: "+fmt.Sprintf("%#v", this.StateTimestamp)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func (this *OwnerDesc) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 7) + s = append(s, "&ring.OwnerDesc{") + s = append(s, "OwnedPartition: "+fmt.Sprintf("%#v", this.OwnedPartition)+",\n") + s = append(s, "State: "+fmt.Sprintf("%#v", this.State)+",\n") + s = append(s, "UpdatedTimestamp: "+fmt.Sprintf("%#v", this.UpdatedTimestamp)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func valueToGoStringPartitionRingDesc(v interface{}, typ string) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) +} +func (m *PartitionRingDesc) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *PartitionRingDesc) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *PartitionRingDesc) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Owners) > 0 { + for k := range m.Owners { + v := m.Owners[k] + baseI := i + { + size, err := (&v).MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + i -= len(k) + copy(dAtA[i:], k) + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(len(k))) + i-- + dAtA[i] = 0xa + i = 
encodeVarintPartitionRingDesc(dAtA, i, uint64(baseI-i)) + i-- + dAtA[i] = 0x12 + } + } + if len(m.Partitions) > 0 { + for k := range m.Partitions { + v := m.Partitions[k] + baseI := i + { + size, err := (&v).MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(k)) + i-- + dAtA[i] = 0x8 + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(baseI-i)) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func (m *PartitionDesc) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *PartitionDesc) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *PartitionDesc) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Id != 0 { + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(m.Id)) + i-- + dAtA[i] = 0x20 + } + if m.StateTimestamp != 0 { + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(m.StateTimestamp)) + i-- + dAtA[i] = 0x18 + } + if m.State != 0 { + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(m.State)) + i-- + dAtA[i] = 0x10 + } + if len(m.Tokens) > 0 { + dAtA4 := make([]byte, len(m.Tokens)*10) + var j3 int + for _, num := range m.Tokens { + for num >= 1<<7 { + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j3++ + } + dAtA4[j3] = uint8(num) + j3++ + } + i -= j3 + copy(dAtA[i:], dAtA4[:j3]) + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(j3)) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *OwnerDesc) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *OwnerDesc) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *OwnerDesc) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.UpdatedTimestamp != 0 { + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(m.UpdatedTimestamp)) + i-- + dAtA[i] = 0x18 + } + if m.State != 0 { + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(m.State)) + i-- + dAtA[i] = 0x10 + } + if m.OwnedPartition != 0 { + i = encodeVarintPartitionRingDesc(dAtA, i, uint64(m.OwnedPartition)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func encodeVarintPartitionRingDesc(dAtA []byte, offset int, v uint64) int { + offset -= sovPartitionRingDesc(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *PartitionRingDesc) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Partitions) > 0 { + for k, v := range m.Partitions { + _ = k + _ = v + l = v.Size() + mapEntrySize := 1 + sovPartitionRingDesc(uint64(k)) + 1 + l + sovPartitionRingDesc(uint64(l)) + n += mapEntrySize + 1 + sovPartitionRingDesc(uint64(mapEntrySize)) + } + } + if len(m.Owners) > 0 { + for k, v := range m.Owners { + _ = k + _ = v + l = v.Size() + mapEntrySize := 1 + len(k) + sovPartitionRingDesc(uint64(len(k))) + 1 + l + sovPartitionRingDesc(uint64(l)) + n += mapEntrySize + 1 + sovPartitionRingDesc(uint64(mapEntrySize)) + } + } + return n +} 
+ +func (m *PartitionDesc) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Tokens) > 0 { + l = 0 + for _, e := range m.Tokens { + l += sovPartitionRingDesc(uint64(e)) + } + n += 1 + sovPartitionRingDesc(uint64(l)) + l + } + if m.State != 0 { + n += 1 + sovPartitionRingDesc(uint64(m.State)) + } + if m.StateTimestamp != 0 { + n += 1 + sovPartitionRingDesc(uint64(m.StateTimestamp)) + } + if m.Id != 0 { + n += 1 + sovPartitionRingDesc(uint64(m.Id)) + } + return n +} + +func (m *OwnerDesc) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.OwnedPartition != 0 { + n += 1 + sovPartitionRingDesc(uint64(m.OwnedPartition)) + } + if m.State != 0 { + n += 1 + sovPartitionRingDesc(uint64(m.State)) + } + if m.UpdatedTimestamp != 0 { + n += 1 + sovPartitionRingDesc(uint64(m.UpdatedTimestamp)) + } + return n +} + +func sovPartitionRingDesc(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozPartitionRingDesc(x uint64) (n int) { + return sovPartitionRingDesc(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (this *PartitionRingDesc) String() string { + if this == nil { + return "nil" + } + keysForPartitions := make([]int32, 0, len(this.Partitions)) + for k, _ := range this.Partitions { + keysForPartitions = append(keysForPartitions, k) + } + github_com_gogo_protobuf_sortkeys.Int32s(keysForPartitions) + mapStringForPartitions := "map[int32]PartitionDesc{" + for _, k := range keysForPartitions { + mapStringForPartitions += fmt.Sprintf("%v: %v,", k, this.Partitions[k]) + } + mapStringForPartitions += "}" + keysForOwners := make([]string, 0, len(this.Owners)) + for k, _ := range this.Owners { + keysForOwners = append(keysForOwners, k) + } + github_com_gogo_protobuf_sortkeys.Strings(keysForOwners) + mapStringForOwners := "map[string]OwnerDesc{" + for _, k := range keysForOwners { + mapStringForOwners += fmt.Sprintf("%v: %v,", k, this.Owners[k]) + } + mapStringForOwners += "}" + s := strings.Join([]string{`&PartitionRingDesc{`, + `Partitions:` + mapStringForPartitions + `,`, + `Owners:` + mapStringForOwners + `,`, + `}`, + }, "") + return s +} +func (this *PartitionDesc) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&PartitionDesc{`, + `Tokens:` + fmt.Sprintf("%v", this.Tokens) + `,`, + `State:` + fmt.Sprintf("%v", this.State) + `,`, + `StateTimestamp:` + fmt.Sprintf("%v", this.StateTimestamp) + `,`, + `Id:` + fmt.Sprintf("%v", this.Id) + `,`, + `}`, + }, "") + return s +} +func (this *OwnerDesc) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&OwnerDesc{`, + `OwnedPartition:` + fmt.Sprintf("%v", this.OwnedPartition) + `,`, + `State:` + fmt.Sprintf("%v", this.State) + `,`, + `UpdatedTimestamp:` + fmt.Sprintf("%v", this.UpdatedTimestamp) + `,`, + `}`, + }, "") + return s +} +func valueToStringPartitionRingDesc(v interface{}) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("*%v", pv) +} +func (m *PartitionRingDesc) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return 
fmt.Errorf("proto: PartitionRingDesc: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PartitionRingDesc: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Partitions", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthPartitionRingDesc + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Partitions == nil { + m.Partitions = make(map[int32]PartitionDesc) + } + var mapkey int32 + mapvalue := &PartitionDesc{} + for iNdEx < postIndex { + entryPreIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + if fieldNum == 1 { + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + mapkey |= int32(b&0x7F) << shift + if b < 0x80 { + break + } + } + } else if fieldNum == 2 { + var mapmsglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + mapmsglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if mapmsglen < 0 { + return ErrInvalidLengthPartitionRingDesc + } + postmsgIndex := iNdEx + mapmsglen + if postmsgIndex < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if postmsgIndex > l { + return io.ErrUnexpectedEOF + } + mapvalue = &PartitionDesc{} + if err := mapvalue.Unmarshal(dAtA[iNdEx:postmsgIndex]); err != nil { + return err + } + iNdEx = postmsgIndex + } else { + iNdEx = entryPreIndex + skippy, err := skipPartitionRingDesc(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if (iNdEx + skippy) > postIndex { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + m.Partitions[mapkey] = *mapvalue + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Owners", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthPartitionRingDesc + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Owners == nil { + m.Owners = make(map[string]OwnerDesc) + } + var mapkey string + mapvalue := &OwnerDesc{} + for iNdEx < postIndex { + entryPreIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= 
uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + if fieldNum == 1 { + var stringLenmapkey uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLenmapkey |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLenmapkey := int(stringLenmapkey) + if intStringLenmapkey < 0 { + return ErrInvalidLengthPartitionRingDesc + } + postStringIndexmapkey := iNdEx + intStringLenmapkey + if postStringIndexmapkey < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if postStringIndexmapkey > l { + return io.ErrUnexpectedEOF + } + mapkey = string(dAtA[iNdEx:postStringIndexmapkey]) + iNdEx = postStringIndexmapkey + } else if fieldNum == 2 { + var mapmsglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + mapmsglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if mapmsglen < 0 { + return ErrInvalidLengthPartitionRingDesc + } + postmsgIndex := iNdEx + mapmsglen + if postmsgIndex < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if postmsgIndex > l { + return io.ErrUnexpectedEOF + } + mapvalue = &OwnerDesc{} + if err := mapvalue.Unmarshal(dAtA[iNdEx:postmsgIndex]); err != nil { + return err + } + iNdEx = postmsgIndex + } else { + iNdEx = entryPreIndex + skippy, err := skipPartitionRingDesc(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if (iNdEx + skippy) > postIndex { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + m.Owners[mapkey] = *mapvalue + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipPartitionRingDesc(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *PartitionDesc) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: PartitionDesc: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PartitionDesc: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v uint32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.Tokens = append(m.Tokens, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { 
+ return ErrInvalidLengthPartitionRingDesc + } + postIndex := iNdEx + packedLen + if postIndex < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + var elementCount int + var count int + for _, integer := range dAtA[iNdEx:postIndex] { + if integer < 128 { + count++ + } + } + elementCount = count + if elementCount != 0 && len(m.Tokens) == 0 { + m.Tokens = make([]uint32, 0, elementCount) + } + for iNdEx < postIndex { + var v uint32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.Tokens = append(m.Tokens, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Tokens", wireType) + } + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field State", wireType) + } + m.State = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.State |= PartitionState(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field StateTimestamp", wireType) + } + m.StateTimestamp = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.StateTimestamp |= int64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Id", wireType) + } + m.Id = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Id |= int32(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipPartitionRingDesc(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *OwnerDesc) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: OwnerDesc: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: OwnerDesc: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field OwnedPartition", wireType) + } + m.OwnedPartition = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.OwnedPartition |= int32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 0 { + return 
fmt.Errorf("proto: wrong wireType = %d for field State", wireType) + } + m.State = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.State |= OwnerState(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field UpdatedTimestamp", wireType) + } + m.UpdatedTimestamp = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.UpdatedTimestamp |= int64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipPartitionRingDesc(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthPartitionRingDesc + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipPartitionRingDesc(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthPartitionRingDesc + } + iNdEx += length + if iNdEx < 0 { + return 0, ErrInvalidLengthPartitionRingDesc + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowPartitionRingDesc + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipPartitionRingDesc(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + if iNdEx < 0 { + return 0, ErrInvalidLengthPartitionRingDesc + } + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthPartitionRingDesc = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowPartitionRingDesc = fmt.Errorf("proto: integer overflow") +) diff --git a/vendor/github.com/grafana/dskit/ring/partition_ring_desc.proto b/vendor/github.com/grafana/dskit/ring/partition_ring_desc.proto new file mode 100644 index 000000000000..d8fb9316f01d --- /dev/null +++ 
b/vendor/github.com/grafana/dskit/ring/partition_ring_desc.proto
@@ -0,0 +1,81 @@
+syntax = "proto3";

+package ring;

+import "github.com/gogo/protobuf/gogoproto/gogo.proto";

+option (gogoproto.marshaler_all) = true;
+option (gogoproto.unmarshaler_all) = true;

+// PartitionRingDesc holds the state of the partitions ring.
+message PartitionRingDesc {
+  // Mapping between partition ID and partition info.
+  map<int32, PartitionDesc> partitions = 1 [(gogoproto.nullable) = false];

+  // Mapping between instance ID and partition ownership info.
+  map<string, OwnerDesc> owners = 2 [(gogoproto.nullable) = false];
+}

+// PartitionDesc holds the state of a single partition.
+message PartitionDesc {
+  // The partition ID. This value is the same as the key in the partitions map in PartitionRingDesc.
+  int32 id = 4;

+  // Unique tokens, generated with a deterministic token generator. Tokens MUST be immutable:
+  // if tokens get changed, the change will not be propagated via memberlist.
+  repeated uint32 tokens = 1;

+  // The state of the partition.
+  PartitionState state = 2;

+  // Unix timestamp (with seconds precision) of when the state last changed for this partition.
+  int64 stateTimestamp = 3;
+}

+enum PartitionState {
+  PartitionUnknown = 0;

+  // Pending partition is a partition that is about to be switched to ACTIVE. This state is used
+  // to let owners attach to the partition and get ready to handle it.
+  //
+  // When a partition is in this state, it must not be used for writing or reading.
+  PartitionPending = 1;

+  // Active partition in read-write mode.
+  PartitionActive = 2;

+  // Inactive partition in read-only mode. This partition will be deleted after a grace period,
+  // unless its state changes to Active again.
+  PartitionInactive = 3;

+  // Deleted partition. This state is not visible to ring clients: it's only used to propagate
+  // via memberlist the information that a partition has been deleted.
+  PartitionDeleted = 4;
+}

+// OwnerDesc holds the information of a partition owner.
+message OwnerDesc {
+  // The partition that belongs to this owner. An owner can own only 1 partition, but 1 partition can be
+  // owned by multiple owners.
+  int32 ownedPartition = 1;

+  // The owner state. This field is used to propagate deletions via memberlist.
+  OwnerState state = 2;

+  // Unix timestamp (with seconds precision) of when the data for the owner was last updated.
+  // This timestamp is used to resolve conflicts when merging updates via memberlist (the most recent
+  // update wins).
+  int64 updatedTimestamp = 3;
+}

+enum OwnerState {
+  OwnerUnknown = 0;

+  // Active owner.
+  OwnerActive = 1;

+  // Deleted owner. This state is not visible to ring clients: it's only used to propagate
+  // via memberlist the information that an owner has been deleted. Owners in this state
+  // are removed before clients can see them.
+  OwnerDeleted = 2;
+}
diff --git a/vendor/github.com/grafana/dskit/ring/partition_ring_editor.go b/vendor/github.com/grafana/dskit/ring/partition_ring_editor.go
new file mode 100644
index 000000000000..a816693e55ca
--- /dev/null
+++ b/vendor/github.com/grafana/dskit/ring/partition_ring_editor.go
@@ -0,0 +1,64 @@
+package ring
+
+import (
+    "context"
+    "time"
+
+    "github.com/pkg/errors"
+
+    "github.com/grafana/dskit/kv"
+)
+
+// PartitionRingEditor is a standalone component that can be used to modify the partitions ring.
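+//
+// A minimal usage sketch (assuming an existing kv.Client named store and a hypothetical
+// ring key):
+//
+//	editor := NewPartitionRingEditor("partition-ring-key", store)
+//	err := editor.ChangePartitionState(ctx, 1, PartitionInactive)
+//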
+// If you want to implement the partition lifecycle you should use PartitionInstanceLifecycler instead. +type PartitionRingEditor struct { + ringKey string + store kv.Client +} + +func NewPartitionRingEditor(ringKey string, store kv.Client) *PartitionRingEditor { + return &PartitionRingEditor{ + ringKey: ringKey, + store: store, + } +} + +// ChangePartitionState changes the partition state to toState. +// This function returns ErrPartitionDoesNotExist if the partition doesn't exist, +// and ErrPartitionStateChangeNotAllowed if the state change is not allowed. +func (l *PartitionRingEditor) ChangePartitionState(ctx context.Context, partitionID int32, toState PartitionState) error { + return l.updateRing(ctx, func(ring *PartitionRingDesc) (bool, error) { + return changePartitionState(ring, partitionID, toState) + }) +} + +func (l *PartitionRingEditor) updateRing(ctx context.Context, update func(ring *PartitionRingDesc) (bool, error)) error { + return l.store.CAS(ctx, l.ringKey, func(in interface{}) (out interface{}, retry bool, err error) { + ringDesc := GetOrCreatePartitionRingDesc(in) + + if changed, err := update(ringDesc); err != nil { + return nil, false, err + } else if !changed { + return nil, false, nil + } + + return ringDesc, true, nil + }) +} + +func changePartitionState(ring *PartitionRingDesc, partitionID int32, toState PartitionState) (changed bool, _ error) { + partition, exists := ring.Partitions[partitionID] + if !exists { + return false, ErrPartitionDoesNotExist + } + + if partition.State == toState { + return false, nil + } + + if !isPartitionStateChangeAllowed(partition.State, toState) { + return false, errors.Wrapf(ErrPartitionStateChangeNotAllowed, "change partition state from %s to %s", partition.State.CleanName(), toState.CleanName()) + } + + return ring.UpdatePartitionState(partitionID, toState, time.Now()), nil +} diff --git a/vendor/github.com/grafana/dskit/ring/partition_ring_http.go b/vendor/github.com/grafana/dskit/ring/partition_ring_http.go new file mode 100644 index 000000000000..8e58c58c7afc --- /dev/null +++ b/vendor/github.com/grafana/dskit/ring/partition_ring_http.go @@ -0,0 +1,158 @@ +package ring + +import ( + "context" + _ "embed" + "fmt" + "html/template" + "net/http" + "sort" + "strconv" + "time" + + "golang.org/x/exp/slices" +) + +//go:embed partition_ring_status.gohtml +var partitionRingPageContent string +var partitionRingPageTemplate = template.Must(template.New("webpage").Funcs(template.FuncMap{ + "mod": func(i, j int32) bool { + return i%j == 0 + }, + "formatTimestamp": func(ts time.Time) string { + return ts.Format("2006-01-02 15:04:05 MST") + }, +}).Parse(partitionRingPageContent)) + +type PartitionRingUpdater interface { + ChangePartitionState(ctx context.Context, partitionID int32, toState PartitionState) error +} + +type PartitionRingPageHandler struct { + reader PartitionRingReader + updater PartitionRingUpdater +} + +func NewPartitionRingPageHandler(reader PartitionRingReader, updater PartitionRingUpdater) *PartitionRingPageHandler { + return &PartitionRingPageHandler{ + reader: reader, + updater: updater, + } +} + +func (h *PartitionRingPageHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodGet: + h.handleGetRequest(w, req) + case http.MethodPost: + h.handlePostRequest(w, req) + default: + http.Error(w, "Unsupported HTTP method", http.StatusMethodNotAllowed) + } +} + +func (h *PartitionRingPageHandler) handleGetRequest(w http.ResponseWriter, req *http.Request) 
{
+	var (
+		ring     = h.reader.PartitionRing()
+		ringDesc = ring.desc
+	)
+
+	// Prepare the data to render partitions in the page.
+	partitionsByID := make(map[int32]partitionPageData, len(ringDesc.Partitions))
+	for id, partition := range ringDesc.Partitions {
+		owners := ring.PartitionOwnerIDsCopy(id)
+		slices.Sort(owners)
+
+		partitionsByID[id] = partitionPageData{
+			ID:             id,
+			Corrupted:      false,
+			State:          partition.State,
+			StateTimestamp: partition.GetStateTime(),
+			OwnerIDs:       owners,
+		}
+	}
+
+	// Look for owners of non-existing partitions. We want to provide visibility for such a case,
+	// and we report the partition in corrupted state.
+	for ownerID, owner := range ringDesc.Owners {
+		partition, exists := partitionsByID[owner.OwnedPartition]
+
+		if !exists {
+			partition = partitionPageData{
+				ID:             owner.OwnedPartition,
+				Corrupted:      true,
+				State:          PartitionUnknown,
+				StateTimestamp: time.Time{},
+				OwnerIDs:       []string{ownerID},
+			}
+
+			partitionsByID[owner.OwnedPartition] = partition
+		}
+
+		if !slices.Contains(partition.OwnerIDs, ownerID) {
+			partition.OwnerIDs = append(partition.OwnerIDs, ownerID)
+			partitionsByID[owner.OwnedPartition] = partition
+		}
+	}
+
+	// Convert partitions to a list and sort it by ID.
+	partitions := make([]partitionPageData, 0, len(partitionsByID))
+
+	for _, partition := range partitionsByID {
+		partitions = append(partitions, partition)
+	}
+
+	sort.Slice(partitions, func(i, j int) bool {
+		return partitions[i].ID < partitions[j].ID
+	})
+
+	renderHTTPResponse(w, partitionRingPageData{
+		Partitions: partitions,
+		PartitionStateChanges: map[PartitionState]PartitionState{
+			PartitionPending:  PartitionActive,
+			PartitionActive:   PartitionInactive,
+			PartitionInactive: PartitionActive,
+		},
+	}, partitionRingPageTemplate, req)
+}
+
+func (h *PartitionRingPageHandler) handlePostRequest(w http.ResponseWriter, req *http.Request) {
+	if req.FormValue("action") == "change_state" {
+		partitionID, err := strconv.Atoi(req.FormValue("partition_id"))
+		if err != nil {
+			http.Error(w, fmt.Sprintf("invalid partition ID: %s", err.Error()), http.StatusBadRequest)
+			return
+		}
+
+		toState, ok := PartitionState_value[req.FormValue("partition_state")]
+		if !ok {
+			http.Error(w, "invalid partition state", http.StatusBadRequest)
+			return
+		}
+
+		if err := h.updater.ChangePartitionState(req.Context(), int32(partitionID), PartitionState(toState)); err != nil {
+			http.Error(w, fmt.Sprintf("failed to change partition state: %s", err.Error()), http.StatusBadRequest)
+			return
+		}
+	}
+
+	// Implement PRG pattern to prevent double-POST and work with CSRF middleware.
+	// https://en.wikipedia.org/wiki/Post/Redirect/Get
+	w.Header().Set("Location", "#")
+	w.WriteHeader(http.StatusFound)
+}
+
+type partitionRingPageData struct {
+	Partitions []partitionPageData `json:"partitions"`
+
+	// PartitionStateChanges maps the allowed state changes through the UI.
+	PartitionStateChanges map[PartitionState]PartitionState `json:"-"`
+}
+
+type partitionPageData struct {
+	ID             int32          `json:"id"`
+	Corrupted      bool           `json:"corrupted"`
+	State          PartitionState `json:"state"`
+	StateTimestamp time.Time      `json:"state_timestamp"`
+	OwnerIDs       []string       `json:"owner_ids"`
+}
diff --git a/vendor/github.com/grafana/dskit/ring/partition_ring_model.go b/vendor/github.com/grafana/dskit/ring/partition_ring_model.go
new file mode 100644
index 000000000000..c95380756a3c
--- /dev/null
+++ b/vendor/github.com/grafana/dskit/ring/partition_ring_model.go
@@ -0,0 +1,460 @@
+package ring
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/gogo/protobuf/proto"
+	"golang.org/x/exp/slices"
+
+	"github.com/grafana/dskit/kv/codec"
+	"github.com/grafana/dskit/kv/memberlist"
+)
+
+type partitionRingCodec struct {
+	codec.Codec
+}
+
+// Decode wraps Codec.Decode and ensures PartitionRingDesc maps are not nil.
+func (c *partitionRingCodec) Decode(in []byte) (interface{}, error) {
+	out, err := c.Codec.Decode(in)
+	if err != nil {
+		return out, err
+	}
+
+	// Ensure maps are initialised. This makes working with PartitionRingDesc more convenient.
+	if actual, ok := out.(*PartitionRingDesc); ok {
+		if actual.Partitions == nil {
+			actual.Partitions = map[int32]PartitionDesc{}
+		}
+		if actual.Owners == nil {
+			actual.Owners = map[string]OwnerDesc{}
+		}
+	}
+
+	return out, nil
+}
+
+func GetPartitionRingCodec() codec.Codec {
+	return &partitionRingCodec{
+		Codec: codec.NewProtoCodec("partitionRingDesc", PartitionRingDescFactory),
+	}
+}
+
+// PartitionRingDescFactory makes new PartitionRingDesc.
+func PartitionRingDescFactory() proto.Message {
+	return NewPartitionRingDesc()
+}
+
+func GetOrCreatePartitionRingDesc(in any) *PartitionRingDesc {
+	if in == nil {
+		return NewPartitionRingDesc()
+	}
+
+	desc := in.(*PartitionRingDesc)
+	if desc == nil {
+		return NewPartitionRingDesc()
+	}
+
+	return desc
+}
+
+func NewPartitionRingDesc() *PartitionRingDesc {
+	return &PartitionRingDesc{
+		Partitions: map[int32]PartitionDesc{},
+		Owners:     map[string]OwnerDesc{},
+	}
+}
+
+// tokens returns a sorted list of tokens registered by all partitions.
+func (m *PartitionRingDesc) tokens() Tokens {
+	allTokens := make(Tokens, 0, len(m.Partitions)*optimalTokensPerInstance)
+
+	for _, partition := range m.Partitions {
+		allTokens = append(allTokens, partition.Tokens...)
+	}
+
+	slices.Sort(allTokens)
+	return allTokens
+}
+
+// partitionByToken returns a map where the key is a registered token and the value is the ID of the partition
+// that registered that token.
+func (m *PartitionRingDesc) partitionByToken() map[Token]int32 {
+	out := make(map[Token]int32, len(m.Partitions)*optimalTokensPerInstance)
+
+	for partitionID, partition := range m.Partitions {
+		for _, token := range partition.Tokens {
+			out[Token(token)] = partitionID
+		}
+	}
+
+	return out
+}
+
+// ownersByPartition returns a map where the key is the partition ID and the value is a list of owner IDs.
+func (m *PartitionRingDesc) ownersByPartition() map[int32][]string {
+	out := make(map[int32][]string, len(m.Partitions))
+	for id, o := range m.Owners {
+		out[o.OwnedPartition] = append(out[o.OwnedPartition], id)
+	}
+
+	// Sort owners to have predictable tests.
+	for id := range out {
+		slices.Sort(out[id])
+	}
+
+	return out
+}
+
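For readers new to these helpers, here is a minimal sketch of assembling a partition ring descriptor with the exported methods above. The `OwnerActive` state and the exact import path are assumptions of this sketch, not shown in this hunk:

```go
package main

import (
	"fmt"
	"time"

	"github.com/grafana/dskit/ring"
)

func main() {
	desc := ring.NewPartitionRingDesc()
	now := time.Now()

	// Two partitions: one serving traffic, one still waiting for owners.
	desc.AddPartition(0, ring.PartitionActive, now)
	desc.AddPartition(1, ring.PartitionPending, now)

	// Each owner maps one instance to exactly one partition.
	desc.AddOrUpdateOwner("ingester-0", ring.OwnerActive, 0, now)
	desc.AddOrUpdateOwner("ingester-1", ring.OwnerActive, 1, now)

	fmt.Println("partitions:", len(desc.Partitions), "owners:", len(desc.Owners))
}
```
+// countPartitionsByState returns a map containing the number of partitions by state.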
+func (m *PartitionRingDesc) countPartitionsByState() map[PartitionState]int {
+	// Init the map to have zero values for all states.
+	out := make(map[PartitionState]int, len(PartitionState_value)-2)
+	for _, state := range PartitionState_value {
+		if PartitionState(state) == PartitionUnknown || PartitionState(state) == PartitionDeleted {
+			continue
+		}
+
+		out[PartitionState(state)] = 0
+	}
+
+	for _, partition := range m.Partitions {
+		out[partition.State]++
+	}
+
+	return out
+}
+
+func (m *PartitionRingDesc) activePartitionsCount() int {
+	count := 0
+	for _, partition := range m.Partitions {
+		if partition.IsActive() {
+			count++
+		}
+	}
+	return count
+}
+
+// WithPartitions returns a new PartitionRingDesc with only the specified partitions and their owners included.
+func (m *PartitionRingDesc) WithPartitions(partitions map[int32]struct{}) PartitionRingDesc {
+	newPartitions := make(map[int32]PartitionDesc, len(partitions))
+	newOwners := make(map[string]OwnerDesc, len(partitions)*2) // assuming two owners per partition.
+
+	for pid, p := range m.Partitions {
+		if _, ok := partitions[pid]; ok {
+			newPartitions[pid] = p
+		}
+	}
+
+	for oid, o := range m.Owners {
+		if _, ok := partitions[o.OwnedPartition]; ok {
+			newOwners[oid] = o
+		}
+	}
+
+	return PartitionRingDesc{
+		Partitions: newPartitions,
+		Owners:     newOwners,
+	}
+}
+
+// AddPartition adds a new partition to the ring. Tokens are auto-generated using the spread minimizing strategy,
+// which generates deterministic unique tokens.
+func (m *PartitionRingDesc) AddPartition(id int32, state PartitionState, now time.Time) {
+	// The spread-minimizing token generator is a deterministic unique-token generator for a given id and zone.
+	// Partitions don't use zones.
+	spreadMinimizing := NewSpreadMinimizingTokenGeneratorForInstanceAndZoneID("", int(id), 0, false)
+
+	m.Partitions[id] = PartitionDesc{
+		Id:             id,
+		Tokens:         spreadMinimizing.GenerateTokens(optimalTokensPerInstance, nil),
+		State:          state,
+		StateTimestamp: now.Unix(),
+	}
+}
+
+// UpdatePartitionState changes the state of a partition. Returns true if the state was changed,
+// or false if the update was a no-op.
+func (m *PartitionRingDesc) UpdatePartitionState(id int32, state PartitionState, now time.Time) bool {
+	d, ok := m.Partitions[id]
+	if !ok {
+		return false
+	}
+
+	if d.State == state {
+		return false
+	}
+
+	d.State = state
+	d.StateTimestamp = now.Unix()
+	m.Partitions[id] = d
+	return true
+}
+
+// RemovePartition removes a partition.
+func (m *PartitionRingDesc) RemovePartition(id int32) {
+	delete(m.Partitions, id)
+}
+
+// HasPartition returns whether a partition exists.
+func (m *PartitionRingDesc) HasPartition(id int32) bool {
+	_, ok := m.Partitions[id]
+	return ok
+}
+
+// AddOrUpdateOwner adds or updates a partition owner in the ring. Returns true if the
+// owner was added or updated, false if it was left unchanged.
+func (m *PartitionRingDesc) AddOrUpdateOwner(id string, state OwnerState, ownedPartition int32, now time.Time) bool {
+	prev, ok := m.Owners[id]
+	updated := OwnerDesc{
+		State:          state,
+		OwnedPartition: ownedPartition,
+
+		// Preserve the previous timestamp so that we'll NOT compare it.
+		// Then, if we detect that the OwnerDesc should be updated, we'll
+		// also update the UpdatedTimestamp.
+		UpdatedTimestamp: prev.UpdatedTimestamp,
+	}
+
+	if ok && prev.Equal(updated) {
+		return false
+	}
+
+	updated.UpdatedTimestamp = now.Unix()
+	m.Owners[id] = updated
+
+	return true
+}
+
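`UpdatePartitionState` above is what a state change driven through the `PartitionRingEditor` (see partition_ring_editor.go earlier in this diff) ultimately executes against the CAS'd descriptor. A hedged sketch of the calling side; the `"consul"` store, the `"partition-ring"` key, and the exact `kv.NewClient` signature are assumptions here:

```go
package main

import (
	"context"

	"github.com/go-kit/log"
	"github.com/grafana/dskit/kv"
	"github.com/grafana/dskit/ring"
)

func main() {
	// Hypothetical KV client; the store type and ring key depend on the deployment.
	client, err := kv.NewClient(kv.Config{Store: "consul"}, ring.GetPartitionRingCodec(), nil, log.NewNopLogger())
	if err != nil {
		panic(err)
	}

	editor := ring.NewPartitionRingEditor("partition-ring", client)

	// Runs the CAS loop shown earlier; ErrPartitionDoesNotExist and
	// ErrPartitionStateChangeNotAllowed surface from here.
	if err := editor.ChangePartitionState(context.Background(), 0, ring.PartitionInactive); err != nil {
		panic(err)
	}
}
```
+// RemoveOwner removes a partition owner. Returns true if the ring has been changed.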
+func (m *PartitionRingDesc) RemoveOwner(id string) bool {
+	if _, ok := m.Owners[id]; !ok {
+		return false
+	}
+
+	delete(m.Owners, id)
+	return true
+}
+
+// HasOwner returns whether an owner exists.
+func (m *PartitionRingDesc) HasOwner(id string) bool {
+	_, ok := m.Owners[id]
+	return ok
+}
+
+// PartitionOwnersCount returns the number of owners for a given partition.
+func (m *PartitionRingDesc) PartitionOwnersCount(partitionID int32) int {
+	count := 0
+	for _, o := range m.Owners {
+		if o.OwnedPartition == partitionID {
+			count++
+		}
+	}
+	return count
+}
+
+// PartitionOwnersCountUpdatedBefore returns the number of owners for a given partition,
+// including only owners which were last updated before the input timestamp.
+func (m *PartitionRingDesc) PartitionOwnersCountUpdatedBefore(partitionID int32, before time.Time) int {
+	count := 0
+	beforeSeconds := before.Unix()
+
+	for _, o := range m.Owners {
+		if o.OwnedPartition == partitionID && o.GetUpdatedTimestamp() < beforeSeconds {
+			count++
+		}
+	}
+	return count
+}
+
+// Merge implements memberlist.Mergeable.
+func (m *PartitionRingDesc) Merge(mergeable memberlist.Mergeable, localCAS bool) (memberlist.Mergeable, error) {
+	return m.mergeWithTime(mergeable, localCAS, time.Now())
+}
+
+func (m *PartitionRingDesc) mergeWithTime(mergeable memberlist.Mergeable, localCAS bool, now time.Time) (memberlist.Mergeable, error) {
+	if mergeable == nil {
+		return nil, nil
+	}
+
+	other, ok := mergeable.(*PartitionRingDesc)
+	if !ok {
+		return nil, fmt.Errorf("expected *PartitionRingDesc, got %T", mergeable)
+	}
+
+	if other == nil {
+		return nil, nil
+	}
+
+	change := NewPartitionRingDesc()
+
+	// Handle partitions.
+	for id, otherPart := range other.Partitions {
+		changed := false
+
+		thisPart, exists := m.Partitions[id]
+		if !exists {
+			changed = true
+			thisPart = otherPart
+		} else {
+			// We don't merge changes to partition ID and tokens because we expect them to be immutable.
+			//
+			// If in the future we change the tokens generation algorithm and have to handle migration to
+			// a different set of tokens, then we'll add the support. For example, we could add a "token generation
+			// version" to PartitionDesc and then preserve tokens generated by the latest version only, or a
+			// timestamp for tokens update too.
+
+			// In case the timestamp is equal we give priority to the deleted state.
+			// Reason is that timestamp has second precision, so we cover the case an
+			// update and subsequent deletion occur within the same second.
+			if otherPart.StateTimestamp > thisPart.StateTimestamp || (otherPart.StateTimestamp == thisPart.StateTimestamp && otherPart.State == PartitionDeleted && thisPart.State != PartitionDeleted) {
+				changed = true
+
+				thisPart.State = otherPart.State
+				thisPart.StateTimestamp = otherPart.StateTimestamp
+			}
+		}
+
+		if changed {
+			m.Partitions[id] = thisPart
+			change.Partitions[id] = thisPart
+		}
+	}
+
+	if localCAS {
+		// Let's mark all missing partitions in the incoming change as deleted.
+		// This breaks commutativity! But we only do it locally, not when gossiping with others.
+		for pid, thisPart := range m.Partitions {
+			if _, exists := other.Partitions[pid]; !exists && thisPart.State != PartitionDeleted {
+				// Partition was removed from the ring. We need to preserve it locally, but we set state to PartitionDeleted.
+ thisPart.State = PartitionDeleted + thisPart.StateTimestamp = now.Unix() + m.Partitions[pid] = thisPart + change.Partitions[pid] = thisPart + } + } + } + + // Now let's handle owners. + for id, otherOwner := range other.Owners { + thisOwner := m.Owners[id] + + // In case the timestamp is equal we give priority to the deleted state. + // Reason is that timestamp has second precision, so we cover the case an + // update and subsequent deletion occur within the same second. + if otherOwner.UpdatedTimestamp > thisOwner.UpdatedTimestamp || (otherOwner.UpdatedTimestamp == thisOwner.UpdatedTimestamp && otherOwner.State == OwnerDeleted && thisOwner.State != OwnerDeleted) { + m.Owners[id] = otherOwner + change.Owners[id] = otherOwner + } + } + + if localCAS { + // Mark all missing owners as deleted. + // This breaks commutativity! But we only do it locally, not when gossiping with others. + for id, thisOwner := range m.Owners { + if _, exists := other.Owners[id]; !exists && thisOwner.State != OwnerDeleted { + // Owner was removed from the ring. We need to preserve it locally, but we set state to OwnerDeleted. + thisOwner.State = OwnerDeleted + thisOwner.UpdatedTimestamp = now.Unix() + m.Owners[id] = thisOwner + change.Owners[id] = thisOwner + } + } + } + + // If nothing changed, report nothing. + if len(change.Partitions) == 0 && len(change.Owners) == 0 { + return nil, nil + } + + return change, nil +} + +// MergeContent implements memberlist.Mergeable. +func (m *PartitionRingDesc) MergeContent() []string { + result := make([]string, len(m.Partitions)+len(m.Owners)) + + // We're assuming that partition IDs and instance IDs are not colliding (ie. no instance is called "1"). + for pid := range m.Partitions { + result = append(result, strconv.Itoa(int(pid))) + } + + for id := range m.Owners { + result = append(result, id) + } + return result +} + +// RemoveTombstones implements memberlist.Mergeable. +func (m *PartitionRingDesc) RemoveTombstones(limit time.Time) (total, removed int) { + for pid, part := range m.Partitions { + if part.State == PartitionDeleted { + if limit.IsZero() || time.Unix(part.StateTimestamp, 0).Before(limit) { + delete(m.Partitions, pid) + removed++ + } else { + total++ + } + } + } + + for n, owner := range m.Owners { + if owner.State == OwnerDeleted { + if limit.IsZero() || time.Unix(owner.UpdatedTimestamp, 0).Before(limit) { + delete(m.Owners, n) + removed++ + } else { + total++ + } + } + } + + return +} + +// Clone implements memberlist.Mergeable. +func (m *PartitionRingDesc) Clone() memberlist.Mergeable { + clone := proto.Clone(m).(*PartitionRingDesc) + + // Ensure empty maps are preserved (easier to compare with a deep equal in tests). 
+	if m.Partitions != nil && clone.Partitions == nil {
+		clone.Partitions = map[int32]PartitionDesc{}
+	}
+	if m.Owners != nil && clone.Owners == nil {
+		clone.Owners = map[string]OwnerDesc{}
+	}
+
+	return clone
+}
+
+func (m *PartitionDesc) IsPending() bool {
+	return m.GetState() == PartitionPending
+}
+
+func (m *PartitionDesc) IsActive() bool {
+	return m.GetState() == PartitionActive
+}
+
+func (m *PartitionDesc) IsInactive() bool {
+	return m.GetState() == PartitionInactive
+}
+
+func (m *PartitionDesc) IsInactiveSince(since time.Time) bool {
+	return m.IsInactive() && m.GetStateTimestamp() < since.Unix()
+}
+
+func (m *PartitionDesc) GetStateTime() time.Time {
+	return time.Unix(m.GetStateTimestamp(), 0)
+}
+
+func (m *PartitionDesc) Clone() PartitionDesc {
+	return *(proto.Clone(m).(*PartitionDesc))
+}
+
+// CleanName returns the PartitionState name without the "Partition" prefix.
+func (s PartitionState) CleanName() string {
+	return strings.TrimPrefix(s.String(), "Partition")
+}
diff --git a/vendor/github.com/grafana/dskit/ring/partition_ring_status.gohtml b/vendor/github.com/grafana/dskit/ring/partition_ring_status.gohtml
new file mode 100644
index 000000000000..f4f9afe87d88
--- /dev/null
+++ b/vendor/github.com/grafana/dskit/ring/partition_ring_status.gohtml
@@ -0,0 +1,63 @@
+{{- /*gotype: github.com/grafana/dskit/ring.partitionRingPageData */ -}}
+<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="UTF-8">
+    <title>Partitions Ring Status</title>
+</head>
+<body>
+<h1>Partitions Ring Status</h1>
+<table border="1">
+    <thead>
+    <tr>
+        <th>Partition ID</th>
+        <th>State</th>
+        <th>State updated at</th>
+        <th>Owners</th>
+        <th>Actions</th>
+    </tr>
+    </thead>
+    <tbody>
+    {{ $stateChanges := .PartitionStateChanges }}
+    {{ range $partition := .Partitions }}
+        <tr>
+            <td>{{ .ID }}</td>
+            <td>
+                {{ if .Corrupted }}
+                    Corrupt
+                {{ else }}
+                    {{ .State.CleanName }}
+                {{ end }}
+            </td>
+            <td>
+                {{ if not .StateTimestamp.IsZero }}
+                    {{ .StateTimestamp | formatTimestamp }}
+                {{ else }}
+                    N/A
+                {{ end }}
+            </td>
+            <td>
+                {{ range $ownerID := $partition.OwnerIDs }}
+                    {{$ownerID}}<br/>
+                {{ end }}
+            </td>
+            <td>
+                {{ if and (not .Corrupted) (ne (index $stateChanges .State) 0) }}
+                    {{ $toState := index $stateChanges .State }}
+                    <form method="POST" action="">
+                        <input type="hidden" name="action" value="change_state"/>
+                        <input type="hidden" name="partition_id" value="{{ .ID }}"/>
+                        <input type="hidden" name="partition_state" value="{{ $toState }}"/>
+                        <button type="submit">Change to {{ $toState.CleanName }}</button>
+                    </form>
+                {{ end }}
+            </td>
+        </tr>
+    {{ end }}
+    </tbody>
+</table>
+</body>
+</html>
\ No newline at end of file
diff --git a/vendor/github.com/grafana/dskit/ring/partition_ring_watcher.go b/vendor/github.com/grafana/dskit/ring/partition_ring_watcher.go
new file mode 100644
index 000000000000..39225697eb0e
--- /dev/null
+++ b/vendor/github.com/grafana/dskit/ring/partition_ring_watcher.go
@@ -0,0 +1,100 @@
+package ring
+
+import (
+	"context"
+	"sync"
+
+	"github.com/go-kit/log"
+	"github.com/go-kit/log/level"
+	"github.com/pkg/errors"
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+
+	"github.com/grafana/dskit/kv"
+	"github.com/grafana/dskit/services"
+)
+
+// PartitionRingWatcher watches the partitions ring for changes in the KV store.
+type PartitionRingWatcher struct {
+	services.Service
+
+	key    string
+	kv     kv.Client
+	logger log.Logger
+
+	ringMx sync.Mutex
+	ring   *PartitionRing
+
+	// Metrics.
+	numPartitionsGaugeVec *prometheus.GaugeVec
+}
+
+func NewPartitionRingWatcher(name, key string, kv kv.Client, logger log.Logger, reg prometheus.Registerer) *PartitionRingWatcher {
+	r := &PartitionRingWatcher{
+		key:    key,
+		kv:     kv,
+		logger: logger,
+		ring:   NewPartitionRing(*NewPartitionRingDesc()),
+		numPartitionsGaugeVec: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{
+			Name:        "partition_ring_partitions",
+			Help:        "Number of partitions by state in the partitions ring.",
+			ConstLabels: map[string]string{"name": name},
+		}, []string{"state"}),
+	}
+
+	r.Service = services.NewBasicService(r.starting, r.loop, nil).WithName("partitions-ring-watcher")
+	return r
+}
+
+func (w *PartitionRingWatcher) starting(ctx context.Context) error {
+	// Get the initial ring state so that, as soon as the service is running, the in-memory
+	// ring is already populated and there's no race condition between when the service is
+	// running and the WatchKey() callback is called for the first time.
+	value, err := w.kv.Get(ctx, w.key)
+	if err != nil {
+		return errors.Wrap(err, "unable to initialise ring state")
+	}
+
+	if value == nil {
+		level.Info(w.logger).Log("msg", "partition ring doesn't exist in KV store yet")
+		value = NewPartitionRingDesc()
+	}
+
+	w.updatePartitionRing(value.(*PartitionRingDesc))
+	return nil
+}
+
+func (w *PartitionRingWatcher) loop(ctx context.Context) error {
+	w.kv.WatchKey(ctx, w.key, func(value interface{}) bool {
+		if value == nil {
+			level.Info(w.logger).Log("msg", "partition ring doesn't exist in KV store yet")
+			return true
+		}
+
+		w.updatePartitionRing(value.(*PartitionRingDesc))
+		return true
+	})
+	return nil
+}
+
+func (w *PartitionRingWatcher) updatePartitionRing(desc *PartitionRingDesc) {
+	newRing := NewPartitionRing(*desc)
+
+	w.ringMx.Lock()
+	w.ring = newRing
+	w.ringMx.Unlock()
+
+	// Update metrics.
+	for state, count := range desc.countPartitionsByState() {
+		w.numPartitionsGaugeVec.WithLabelValues(state.CleanName()).Set(float64(count))
+	}
+}
+
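A sketch of wiring the watcher into an application. The `"inmemory"` store, the ring key, and the `kv.NewClient` signature are assumptions; `NewPartitionRingWatcher` comes from this diff and the services helpers from dskit's services package:

```go
package main

import (
	"context"

	"github.com/go-kit/log"
	"github.com/grafana/dskit/kv"
	"github.com/grafana/dskit/ring"
	"github.com/grafana/dskit/services"
)

func main() {
	logger := log.NewNopLogger()

	client, err := kv.NewClient(kv.Config{Store: "inmemory"}, ring.GetPartitionRingCodec(), nil, logger)
	if err != nil {
		panic(err)
	}

	watcher := ring.NewPartitionRingWatcher("ingester-partitions", "partition-ring", client, logger, nil)
	if err := services.StartAndAwaitRunning(context.Background(), watcher); err != nil {
		panic(err)
	}
	defer services.StopAndAwaitTerminated(context.Background(), watcher) //nolint:errcheck

	// Each call returns the current immutable snapshot; re-fetch it to observe updates.
	snapshot := watcher.PartitionRing()
	_ = snapshot
}
```
+// PartitionRing returns the most updated snapshot of the PartitionRing. The returned instance
+// is immutable and will not be updated if new changes are done to the ring.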
+func (w *PartitionRingWatcher) PartitionRing() *PartitionRing { + w.ringMx.Lock() + defer w.ringMx.Unlock() + + return w.ring +} diff --git a/vendor/github.com/grafana/dskit/ring/partitions_ring_shuffle_shard_cache.go b/vendor/github.com/grafana/dskit/ring/partitions_ring_shuffle_shard_cache.go new file mode 100644 index 000000000000..ce80d2c14adc --- /dev/null +++ b/vendor/github.com/grafana/dskit/ring/partitions_ring_shuffle_shard_cache.go @@ -0,0 +1,96 @@ +package ring + +import ( + "math" + "sync" + "time" +) + +type partitionRingShuffleShardCache struct { + mtx sync.RWMutex + cacheWithoutLookback map[subringCacheKey]*PartitionRing + cacheWithLookback map[subringCacheKey]cachedSubringWithLookback[*PartitionRing] +} + +func newPartitionRingShuffleShardCache() *partitionRingShuffleShardCache { + return &partitionRingShuffleShardCache{ + cacheWithoutLookback: map[subringCacheKey]*PartitionRing{}, + cacheWithLookback: map[subringCacheKey]cachedSubringWithLookback[*PartitionRing]{}, + } +} + +func (r *partitionRingShuffleShardCache) setSubring(identifier string, size int, subring *PartitionRing) { + if subring == nil { + return + } + + r.mtx.Lock() + defer r.mtx.Unlock() + + r.cacheWithoutLookback[subringCacheKey{identifier: identifier, shardSize: size}] = subring +} + +func (r *partitionRingShuffleShardCache) getSubring(identifier string, size int) *PartitionRing { + r.mtx.RLock() + defer r.mtx.RUnlock() + + cached := r.cacheWithoutLookback[subringCacheKey{identifier: identifier, shardSize: size}] + if cached == nil { + return nil + } + + return cached +} + +func (r *partitionRingShuffleShardCache) setSubringWithLookback(identifier string, size int, lookbackPeriod time.Duration, now time.Time, subring *PartitionRing) { + if subring == nil { + return + } + + var ( + lookbackWindowStart = now.Add(-lookbackPeriod).Unix() + validForLookbackWindowsStartingBefore = int64(math.MaxInt64) + ) + + for _, partition := range subring.desc.Partitions { + stateChangedDuringLookbackWindow := partition.StateTimestamp >= lookbackWindowStart + + if stateChangedDuringLookbackWindow && partition.StateTimestamp < validForLookbackWindowsStartingBefore { + validForLookbackWindowsStartingBefore = partition.StateTimestamp + } + } + + r.mtx.Lock() + defer r.mtx.Unlock() + + // Only update cache if subring's lookback window starts later than the previously cached subring for this identifier, + // if there is one. This prevents cache thrashing due to different calls competing if their lookback windows start + // before and after the time a partition state has changed. 
+ key := subringCacheKey{identifier: identifier, shardSize: size, lookbackPeriod: lookbackPeriod} + + if existingEntry, haveCached := r.cacheWithLookback[key]; !haveCached || existingEntry.validForLookbackWindowsStartingAfter < lookbackWindowStart { + r.cacheWithLookback[key] = cachedSubringWithLookback[*PartitionRing]{ + subring: subring, + validForLookbackWindowsStartingAfter: lookbackWindowStart, + validForLookbackWindowsStartingBefore: validForLookbackWindowsStartingBefore, + } + } +} + +func (r *partitionRingShuffleShardCache) getSubringWithLookback(identifier string, size int, lookbackPeriod time.Duration, now time.Time) *PartitionRing { + r.mtx.RLock() + defer r.mtx.RUnlock() + + cached, ok := r.cacheWithLookback[subringCacheKey{identifier: identifier, shardSize: size, lookbackPeriod: lookbackPeriod}] + if !ok { + return nil + } + + lookbackWindowStart := now.Add(-lookbackPeriod).Unix() + if lookbackWindowStart < cached.validForLookbackWindowsStartingAfter || lookbackWindowStart > cached.validForLookbackWindowsStartingBefore { + // The cached subring is not valid for the lookback window that has been requested. + return nil + } + + return cached.subring +} diff --git a/vendor/github.com/grafana/dskit/ring/replication_set.go b/vendor/github.com/grafana/dskit/ring/replication_set.go index f05153c0525c..ffdcf80ab526 100644 --- a/vendor/github.com/grafana/dskit/ring/replication_set.go +++ b/vendor/github.com/grafana/dskit/ring/replication_set.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sort" + "sync" "time" kitlog "github.com/go-kit/log" @@ -388,6 +389,111 @@ func DoUntilQuorumWithoutSuccessfulContextCancellation[T any](ctx context.Contex return results, nil } +// DoMultiUntilQuorumWithoutSuccessfulContextCancellation behaves similar to DoUntilQuorumWithoutSuccessfulContextCancellation +// with the following exceptions: +// +// - This function calls DoUntilQuorumWithoutSuccessfulContextCancellation for each input ReplicationSet and requires +// DoUntilQuorumWithoutSuccessfulContextCancellation to successfully run for each of them. Execution breaks on the +// first error returned by DoUntilQuorumWithoutSuccessfulContextCancellation on any ReplicationSet. +// +// - This function requires that the callback function f always call context.CancelCauseFunc once done. Failing to +// cancel the context will leak resources. +func DoMultiUntilQuorumWithoutSuccessfulContextCancellation[T any](ctx context.Context, sets []ReplicationSet, cfg DoUntilQuorumConfig, f func(context.Context, *InstanceDesc, context.CancelCauseFunc) (T, error), cleanupFunc func(T)) ([]T, error) { + if len(sets) == 0 { + return nil, errors.New("no replication sets") + } + if len(sets) == 1 { + return DoUntilQuorumWithoutSuccessfulContextCancellation[T](ctx, sets[0], cfg, f, cleanupFunc) + } + + results, _, err := doMultiUntilQuorumWithoutSuccessfulContextCancellation[T](ctx, sets, cfg, f, cleanupFunc) + return results, err +} + +// See DoMultiUntilQuorumWithoutSuccessfulContextCancellation(). +// +// The returned context.Context is the internal context used by workers and it's used for testing purposes. 
+func doMultiUntilQuorumWithoutSuccessfulContextCancellation[T any](ctx context.Context, sets []ReplicationSet, cfg DoUntilQuorumConfig, f func(context.Context, *InstanceDesc, context.CancelCauseFunc) (T, error), cleanupFunc func(T)) ([]T, context.Context, error) {
+	var (
+		returnResultsMx = sync.Mutex{}
+		returnResults   = make([]T, 0, len(sets)*len(sets[0].Instances)) // Assume all replication sets have the same number of instances.
+
+		returnErrOnce sync.Once
+		returnErr     error // The first error that occurred.
+
+		workersGroup                 = sync.WaitGroup{}
+		workersCtx, cancelWorkersCtx = context.WithCancelCause(ctx)
+
+		inflightTracker = newInflightInstanceTracker(sets)
+	)
+
+	cancelWorkersCtxIfSafe := func() {
+		if inflightTracker.allInstancesCompleted() {
+			cancelWorkersCtx(errors.New("all requests completed"))
+		}
+	}
+
+	// Start a worker for each set. A worker is responsible for calling DoUntilQuorumWithoutSuccessfulContextCancellation()
+	// for the given replication set and handling the result.
+	workersGroup.Add(len(sets))
+
+	for idx, set := range sets {
+		go func(idx int, set ReplicationSet) {
+			defer workersGroup.Done()
+
+			wrappedFn := func(ctx context.Context, instance *InstanceDesc, cancelCtx context.CancelCauseFunc) (T, error) {
+				// The callback function has been called, so we need to track it.
+				inflightTracker.addInstance(idx, instance)
+
+				// Inject custom logic in the context.CancelCauseFunc.
+				return f(ctx, instance, func(cause error) {
+					// Call the original one.
+					cancelCtx(cause)
+
+					// The callback is done, so we can remove it from the tracker and then check
+					// whether it's safe to cancel the workers context.
+					inflightTracker.removeInstance(idx, instance)
+					cancelWorkersCtxIfSafe()
+				})
+			}
+
+			setResults, setErr := DoUntilQuorumWithoutSuccessfulContextCancellation[T](workersCtx, set, cfg, wrappedFn, cleanupFunc)
+
+			if setErr != nil {
+				returnErrOnce.Do(func() {
+					returnErr = setErr
+
+					// Interrupt the execution of all workers.
+					cancelWorkersCtx(setErr)
+				})
+
+				return
+			}
+
+			// Keep track of the results.
+			returnResultsMx.Lock()
+			returnResults = append(returnResults, setResults...)
+			returnResultsMx.Unlock()
+		}(idx, set)
+	}
+
+	// Wait until all goroutines have terminated.
+	workersGroup.Wait()
+
+	// All workers completed, so it's guaranteed that returnResults and returnErr won't be accessed by workers anymore,
+	// and it's safe to read them with no locking.
+	if returnErr != nil {
+		return nil, workersCtx, returnErr
+	}
+
+	// No error occurred. It means the workers context hasn't been canceled yet, and we don't expect more callbacks
+	// to get tracked, so we can check whether the cancellation condition has already been reached and, if so, cancel it.
+	inflightTracker.allInstancesAdded()
+	cancelWorkersCtxIfSafe()
+
+	return returnResults, workersCtx, nil
+}
+
 type instanceResult[T any] struct {
 	result T
 	err    error
@@ -405,6 +511,16 @@ func (r ReplicationSet) Includes(addr string) bool {
 	return false
 }
 
+// GetIDs returns the IDs of all instances within the replication set. Returned slice
+// order is not guaranteed.
+func (r ReplicationSet) GetIDs() []string {
+	ids := make([]string, 0, len(r.Instances))
+	for _, desc := range r.Instances {
+		ids = append(ids, desc.Id)
+	}
+	return ids
+}
+
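To make the cancellation contract concrete, a hedged sketch of a caller; `queryInstance` is a hypothetical per-instance request function, not part of dskit:

```go
package example

import (
	"context"
	"errors"

	"github.com/grafana/dskit/ring"
)

// queryAllSets fans a hypothetical request out to every replication set (for
// example, one set per partition) and collects the per-instance results.
func queryAllSets(ctx context.Context, sets []ring.ReplicationSet, queryInstance func(context.Context, string) (string, error)) ([]string, error) {
	return ring.DoMultiUntilQuorumWithoutSuccessfulContextCancellation(
		ctx, sets, ring.DoUntilQuorumConfig{},
		func(ctx context.Context, inst *ring.InstanceDesc, cancel context.CancelCauseFunc) (string, error) {
			res, err := queryInstance(ctx, inst.Addr)
			// Per the contract above, f must always call the cancel function once
			// done; forgetting it leaks the per-instance context.
			cancel(errors.New("request complete"))
			return res, err
		},
		func(string) {}, // cleanup for results discarded after an error
	)
}
```
 // GetAddresses returns the addresses of all instances within the replication set. Returned slice
 // order is not guaranteed.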
 func (r ReplicationSet) GetAddresses() []string {
@@ -468,6 +584,17 @@ func HasReplicationSetChangedWithoutState(before, after ReplicationSet) bool {
 	})
 }
 
+// HasReplicationSetChangedWithoutStateOrAddr returns false if two replication sets
+// are the same (with possibly different timestamps, instance states, and IP addresses),
+// true if they differ in any other way (number of instances, tokens, zones, ...).
+func HasReplicationSetChangedWithoutStateOrAddr(before, after ReplicationSet) bool {
+	return hasReplicationSetChangedExcluding(before, after, func(i *InstanceDesc) {
+		i.Timestamp = 0
+		i.State = PENDING
+		i.Addr = ""
+	})
+}
+
 // Do comparison of replicasets, but apply a function first
 // to be able to exclude (reset) some values
 func hasReplicationSetChangedExcluding(before, after ReplicationSet, exclude func(*InstanceDesc)) bool {
@@ -478,8 +605,8 @@ func hasReplicationSetChangedExcluding(before, after ReplicationSet, exclude fun
 		return true
 	}
 
-	sort.Sort(ByAddr(beforeInstances))
-	sort.Sort(ByAddr(afterInstances))
+	sort.Sort(ByID(beforeInstances))
+	sort.Sort(ByID(afterInstances))
 
 	for i := 0; i < len(beforeInstances); i++ {
 		b := beforeInstances[i]
diff --git a/vendor/github.com/grafana/dskit/ring/replication_set_tracker.go b/vendor/github.com/grafana/dskit/ring/replication_set_tracker.go
index 202b568bb956..73da1bc37f8a 100644
--- a/vendor/github.com/grafana/dskit/ring/replication_set_tracker.go
+++ b/vendor/github.com/grafana/dskit/ring/replication_set_tracker.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"math/rand"
+	"sync"
 
 	"github.com/go-kit/log"
 	"github.com/go-kit/log/level"
@@ -465,3 +466,91 @@ func (t *zoneAwareContextTracker) cancelAllContexts(cause error) {
 		delete(t.cancelFuncs, instance)
 	}
 }
+
+type inflightInstanceTracker struct {
+	mx       sync.Mutex
+	inflight [][]*InstanceDesc
+
+	// expectMoreInstances is true if more instances are expected to be added to the tracker.
+	expectMoreInstances bool
+}
+
+func newInflightInstanceTracker(sets []ReplicationSet) *inflightInstanceTracker {
+	// Init the inflight tracker.
+	inflight := make([][]*InstanceDesc, len(sets))
+	for idx, set := range sets {
+		inflight[idx] = make([]*InstanceDesc, 0, len(set.Instances))
+	}
+
+	return &inflightInstanceTracker{
+		inflight:            inflight,
+		expectMoreInstances: true,
+	}
+}
+
+// addInstance adds the instance for replicationSetIdx to the tracker.
+//
+// addInstance is idempotent.
+func (t *inflightInstanceTracker) addInstance(replicationSetIdx int, instance *InstanceDesc) {
+	t.mx.Lock()
+	defer t.mx.Unlock()
+
+	// Check if the instance has already been added.
+	for _, curr := range t.inflight[replicationSetIdx] {
+		if curr == instance {
+			return
+		}
+	}
+
+	t.inflight[replicationSetIdx] = append(t.inflight[replicationSetIdx], instance)
+}
+
+// removeInstance removes the instance for replicationSetIdx from the tracker.
+//
+// removeInstance is idempotent.
+func (t *inflightInstanceTracker) removeInstance(replicationSetIdx int, instance *InstanceDesc) {
+	t.mx.Lock()
+	defer t.mx.Unlock()
+
+	for i, curr := range t.inflight[replicationSetIdx] {
+		if curr == instance {
+			instances := t.inflight[replicationSetIdx]
+			t.inflight[replicationSetIdx] = append(instances[:i], instances[i+1:]...)
+
+			// We can safely break the loop because we don't expect multiple occurrences of the same instance.
+			return
+		}
+	}
+}
+
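The intended call sequence, as a package-internal sketch (the tracker is unexported, so this illustrates the protocol the multi-set helper follows rather than a public API):

```go
// Package-internal sketch; only compiles inside the ring package.
func exampleTrackerLifecycle(sets []ReplicationSet) bool {
	t := newInflightInstanceTracker(sets)

	// A worker registers an instance when its callback starts running...
	inst := &sets[0].Instances[0]
	t.addInstance(0, inst)

	// ...and removes it when the callback's cancel function is invoked.
	t.removeInstance(0, inst)

	// Once no further callbacks can start, the tracker is notified, after
	// which allInstancesCompleted can report true.
	t.allInstancesAdded()
	return t.allInstancesCompleted()
}
```
+// allInstancesAdded signals the tracker that all expected instances have been added.
+//
+// allInstancesAdded is idempotent.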
+func (t *inflightInstanceTracker) allInstancesAdded() {
+	t.mx.Lock()
+	defer t.mx.Unlock()
+
+	t.expectMoreInstances = false
+}
+
+// allInstancesCompleted returns true if and only if no more instances are expected to be
+// added to the tracker and all previously tracked instances have been removed by calling removeInstance().
+func (t *inflightInstanceTracker) allInstancesCompleted() bool {
+	t.mx.Lock()
+	defer t.mx.Unlock()
+
+	// We can't assert all instances have completed if it's still possible
+	// to add new ones to the tracker.
+	if t.expectMoreInstances {
+		return false
+	}
+
+	// Ensure there are no inflight instances for any replication set.
+	for _, instances := range t.inflight {
+		if len(instances) > 0 {
+			return false
+		}
+	}
+
+	return true
+}
diff --git a/vendor/github.com/grafana/dskit/ring/replication_strategy.go b/vendor/github.com/grafana/dskit/ring/replication_strategy.go
index 44e05a53833c..db2b283548f2 100644
--- a/vendor/github.com/grafana/dskit/ring/replication_strategy.go
+++ b/vendor/github.com/grafana/dskit/ring/replication_strategy.go
@@ -109,11 +109,3 @@ func (r *Ring) IsHealthy(instance *InstanceDesc, op Operation, now time.Time) bo
 func (r *Ring) ReplicationFactor() int {
 	return r.cfg.ReplicationFactor
 }
-
-// InstancesCount returns the number of instances in the ring.
-func (r *Ring) InstancesCount() int {
-	r.mtx.RLock()
-	c := len(r.ringDesc.Ingesters)
-	r.mtx.RUnlock()
-	return c
-}
diff --git a/vendor/github.com/grafana/dskit/ring/ring.go b/vendor/github.com/grafana/dskit/ring/ring.go
index 0c54bb1c5433..e1c1f6a5159d 100644
--- a/vendor/github.com/grafana/dskit/ring/ring.go
+++ b/vendor/github.com/grafana/dskit/ring/ring.go
@@ -58,6 +58,9 @@ type ReadRing interface {
 	// InstancesCount returns the number of instances in the ring.
 	InstancesCount() int
 
+	// InstancesWithTokensCount returns the number of instances in the ring that have tokens.
+	InstancesWithTokensCount() int
+
 	// ShuffleShard returns a subring for the provided identifier (eg. a tenant ID)
 	// and size (number of instances).
 	ShuffleShard(identifier string, size int) ReadRing
@@ -78,6 +81,15 @@ type ReadRing interface {
 
 	// GetTokenRangesForInstance returns the token ranges owned by an instance in the ring
 	GetTokenRangesForInstance(instanceID string) (TokenRanges, error)
+
+	// InstancesInZoneCount returns the number of instances in the ring that are registered in given zone.
+	InstancesInZoneCount(zone string) int
+
+	// InstancesWithTokensInZoneCount returns the number of instances in the ring that are registered in given zone and have tokens.
+	InstancesWithTokensInZoneCount(zone string) int
+
+	// ZonesCount returns the number of zones for which there's at least 1 instance registered in the ring.
+	ZonesCount() int
 }
 
 var (
@@ -184,10 +196,19 @@ type Ring struct {
 	// to be sorted alphabetically.
 	ringZones []string
 
+	// Number of registered instances with tokens.
+	instancesWithTokensCount int
+
+	// Number of registered instances per zone.
+	instancesCountPerZone map[string]int
+
+	// Number of registered instances with tokens per zone.
+	instancesWithTokensCountPerZone map[string]int
+
 	// Cache of shuffle-sharded subrings per identifier. Invalidated when topology changes.
 	// If set to nil, no caching is done (used by tests, and subrings).
shuffledSubringCache map[subringCacheKey]*Ring - shuffledSubringWithLookbackCache map[subringCacheKey]cachedSubringWithLookback + shuffledSubringWithLookbackCache map[subringCacheKey]cachedSubringWithLookback[*Ring] numMembersGaugeVec *prometheus.GaugeVec totalTokensGauge prometheus.Gauge @@ -202,8 +223,8 @@ type subringCacheKey struct { lookbackPeriod time.Duration } -type cachedSubringWithLookback struct { - subring *Ring +type cachedSubringWithLookback[R any] struct { + subring R validForLookbackWindowsStartingAfter int64 // if the lookback window is from T to S, validForLookbackWindowsStartingAfter is the earliest value of T this cache entry is valid for validForLookbackWindowsStartingBefore int64 // if the lookback window is from T to S, validForLookbackWindowsStartingBefore is the latest value of T this cache entry is valid for } @@ -237,7 +258,7 @@ func NewWithStoreClientAndStrategy(cfg Config, name, key string, store kv.Client strategy: strategy, ringDesc: &Desc{}, shuffledSubringCache: map[subringCacheKey]*Ring{}, - shuffledSubringWithLookbackCache: map[subringCacheKey]cachedSubringWithLookback{}, + shuffledSubringWithLookbackCache: map[subringCacheKey]cachedSubringWithLookback[*Ring]{}, numMembersGaugeVec: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ Name: "ring_members", Help: "Number of members in the ring", @@ -333,6 +354,9 @@ func (r *Ring) updateRingState(ringDesc *Desc) { ringInstanceByToken := ringDesc.getTokensInfo() ringZones := getZones(ringTokensByZone) oldestRegisteredTimestamp := ringDesc.getOldestRegisteredTimestamp() + instancesWithTokensCount := ringDesc.instancesWithTokensCount() + instancesCountPerZone := ringDesc.instancesCountPerZone() + instancesWithTokensCountPerZone := ringDesc.instancesWithTokensCountPerZone() r.mtx.Lock() defer r.mtx.Unlock() @@ -341,6 +365,9 @@ func (r *Ring) updateRingState(ringDesc *Desc) { r.ringTokensByZone = ringTokensByZone r.ringInstanceByToken = ringInstanceByToken r.ringZones = ringZones + r.instancesWithTokensCount = instancesWithTokensCount + r.instancesCountPerZone = instancesCountPerZone + r.instancesWithTokensCountPerZone = instancesWithTokensCountPerZone r.oldestRegisteredTimestamp = oldestRegisteredTimestamp r.lastTopologyChange = now @@ -349,7 +376,7 @@ func (r *Ring) updateRingState(ringDesc *Desc) { r.shuffledSubringCache = make(map[subringCacheKey]*Ring) } if r.shuffledSubringWithLookbackCache != nil { - r.shuffledSubringWithLookbackCache = make(map[subringCacheKey]cachedSubringWithLookback) + r.shuffledSubringWithLookbackCache = make(map[subringCacheKey]cachedSubringWithLookback[*Ring]) } r.updateRingMetrics(rc) @@ -676,7 +703,7 @@ func (r *Ring) ShuffleShard(identifier string, size int) ReadRing { // operations (read only). // // This function supports caching, but the cache will only be effective if successive calls for the -// same identifier are for increasing values of (now-lookbackPeriod). +// same identifier are with the same lookbackPeriod and increasing values of now. func (r *Ring) ShuffleShardWithLookback(identifier string, size int, lookbackPeriod time.Duration, now time.Time) ReadRing { // Nothing to do if the shard size is not smaller then the actual ring. 
 	if size <= 0 || r.InstancesCount() <= size {
@@ -797,12 +824,15 @@ func (r *Ring) shuffleShard(identifier string, size int, lookbackPeriod time.Dur
 	shardTokens := mergeTokenGroups(shardTokensByZone)
 
 	return &Ring{
-		cfg:              r.cfg,
-		strategy:         r.strategy,
-		ringDesc:         shardDesc,
-		ringTokens:       shardTokens,
-		ringTokensByZone: shardTokensByZone,
-		ringZones:        getZones(shardTokensByZone),
+		cfg:                             r.cfg,
+		strategy:                        r.strategy,
+		ringDesc:                        shardDesc,
+		ringTokens:                      shardTokens,
+		ringTokensByZone:                shardTokensByZone,
+		ringZones:                       getZones(shardTokensByZone),
+		instancesWithTokensCount:        shardDesc.instancesWithTokensCount(),
+		instancesCountPerZone:           shardDesc.instancesCountPerZone(),
+		instancesWithTokensCountPerZone: shardDesc.instancesWithTokensCountPerZone(),
 
 		oldestRegisteredTimestamp: shardDesc.getOldestRegisteredTimestamp(),
 
@@ -866,16 +896,32 @@ func mergeTokenGroups(groupsByName map[string][]uint32) []uint32 {
 	return merged
 }
 
-// GetInstanceState returns the current state of an instance or an error if the
-// instance does not exist in the ring.
-func (r *Ring) GetInstanceState(instanceID string) (InstanceState, error) {
+// GetInstance returns the InstanceDesc for the given instanceID or an error
+// if the instance doesn't exist in the ring. The returned InstanceDesc is NOT a
+// deep copy, so the caller should never modify it.
+func (r *Ring) GetInstance(instanceID string) (doNotModify InstanceDesc, _ error) {
 	r.mtx.RLock()
 	defer r.mtx.RUnlock()
 
 	instances := r.ringDesc.GetIngesters()
+	if instances == nil {
+		return InstanceDesc{}, ErrInstanceNotFound
+	}
+
 	instance, ok := instances[instanceID]
 	if !ok {
-		return PENDING, ErrInstanceNotFound
+		return InstanceDesc{}, ErrInstanceNotFound
+	}
+
+	return instance, nil
+}
+
+// GetInstanceState returns the current state of an instance or an error if the
+// instance does not exist in the ring.
+func (r *Ring) GetInstanceState(instanceID string) (InstanceState, error) {
+	instance, err := r.GetInstance(instanceID)
+	if err != nil {
+		return PENDING, err
 	}
 
 	return instance.GetState(), nil
@@ -1017,7 +1063,7 @@ func (r *Ring) setCachedShuffledSubringWithLookback(identifier string, size int,
 	key := subringCacheKey{identifier: identifier, shardSize: size, lookbackPeriod: lookbackPeriod}
 
 	if existingEntry, haveCached := r.shuffledSubringWithLookbackCache[key]; !haveCached || existingEntry.validForLookbackWindowsStartingAfter < lookbackWindowStart {
-		r.shuffledSubringWithLookbackCache[key] = cachedSubringWithLookback{
+		r.shuffledSubringWithLookbackCache[key] = cachedSubringWithLookback[*Ring]{
 			subring:                               subring,
 			validForLookbackWindowsStartingAfter:  lookbackWindowStart,
 			validForLookbackWindowsStartingBefore: validForLookbackWindowsStartingBefore,
@@ -1063,6 +1109,45 @@ func (r *Ring) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	newRingPageHandler(r, r.cfg.HeartbeatTimeout).handle(w, req)
 }
 
+// InstancesCount returns the number of instances in the ring.
+func (r *Ring) InstancesCount() int {
+	r.mtx.RLock()
+	c := len(r.ringDesc.Ingesters)
+	r.mtx.RUnlock()
+	return c
+}
+
+// InstancesWithTokensCount returns the number of instances in the ring that have tokens.
+func (r *Ring) InstancesWithTokensCount() int {
+	r.mtx.RLock()
+	defer r.mtx.RUnlock()
+
+	return r.instancesWithTokensCount
+}
+
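A small sketch of what the new per-zone counters enable, for example a zone-balance check; the balance criterion here is illustrative, not from dskit:

```go
package example

import "github.com/grafana/dskit/ring"

// zonesAreBalanced reports whether every expected zone is present and all of
// its registered instances already own tokens.
func zonesAreBalanced(r ring.ReadRing, zones []string) bool {
	if r.ZonesCount() != len(zones) {
		return false
	}
	for _, zone := range zones {
		if r.InstancesInZoneCount(zone) == 0 {
			return false
		}
		// Instances without tokens are typically still joining the ring.
		if r.InstancesWithTokensInZoneCount(zone) != r.InstancesInZoneCount(zone) {
			return false
		}
	}
	return true
}
```
+// InstancesInZoneCount returns the number of instances in the ring that are registered in given zone.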
+func (r *Ring) InstancesInZoneCount(zone string) int { + r.mtx.RLock() + defer r.mtx.RUnlock() + + return r.instancesCountPerZone[zone] +} + +// InstancesWithTokensInZoneCount returns the number of instances in the ring that are registered in given zone and have tokens. +func (r *Ring) InstancesWithTokensInZoneCount(zone string) int { + r.mtx.RLock() + defer r.mtx.RUnlock() + + return r.instancesWithTokensCountPerZone[zone] +} + +func (r *Ring) ZonesCount() int { + r.mtx.RLock() + defer r.mtx.RUnlock() + + return len(r.ringZones) +} + // Operation describes which instances can be included in the replica set, based on their state. // // Implemented as bitmap, with upper 16-bits used for encoding extendReplicaSet, and lower 16-bits used for encoding healthy states. diff --git a/vendor/github.com/grafana/dskit/ring/http.go b/vendor/github.com/grafana/dskit/ring/ring_http.go similarity index 96% rename from vendor/github.com/grafana/dskit/ring/http.go rename to vendor/github.com/grafana/dskit/ring/ring_http.go index e70b3e6f0a1f..7300430ddac1 100644 --- a/vendor/github.com/grafana/dskit/ring/http.go +++ b/vendor/github.com/grafana/dskit/ring/ring_http.go @@ -13,7 +13,7 @@ import ( "time" ) -//go:embed status.gohtml +//go:embed ring_status.gohtml var defaultPageContent string var defaultPageTemplate = template.Must(template.New("webpage").Funcs(template.FuncMap{ "mod": func(i, j int) bool { return i%j == 0 }, @@ -134,7 +134,7 @@ func (h *ringPageHandler) handle(w http.ResponseWriter, req *http.Request) { // RenderHTTPResponse either responds with json or a rendered html page using the passed in template // by checking the Accepts header -func renderHTTPResponse(w http.ResponseWriter, v httpResponse, t *template.Template, r *http.Request) { +func renderHTTPResponse(w http.ResponseWriter, v any, t *template.Template, r *http.Request) { accept := r.Header.Get("Accept") if strings.Contains(accept, "application/json") { writeJSONResponse(w, v) @@ -161,7 +161,7 @@ func (h *ringPageHandler) forget(ctx context.Context, id string) error { } // WriteJSONResponse writes some JSON as a HTTP response. 
-func writeJSONResponse(w http.ResponseWriter, v httpResponse) { +func writeJSONResponse(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(v); err != nil { diff --git a/vendor/github.com/grafana/dskit/ring/status.gohtml b/vendor/github.com/grafana/dskit/ring/ring_status.gohtml similarity index 100% rename from vendor/github.com/grafana/dskit/ring/status.gohtml rename to vendor/github.com/grafana/dskit/ring/ring_status.gohtml diff --git a/vendor/github.com/grafana/dskit/ring/spread_minimizing_token_generator.go b/vendor/github.com/grafana/dskit/ring/spread_minimizing_token_generator.go index 2363825076fc..bd2ed9970a59 100644 --- a/vendor/github.com/grafana/dskit/ring/spread_minimizing_token_generator.go +++ b/vendor/github.com/grafana/dskit/ring/spread_minimizing_token_generator.go @@ -8,10 +8,6 @@ import ( "sort" "strconv" - "github.com/go-kit/log" - "github.com/go-kit/log/level" - "github.com/pkg/errors" - "golang.org/x/exp/slices" ) @@ -22,11 +18,10 @@ const ( ) var ( - instanceIDRegex = regexp.MustCompile(`^(.*)-(\d+)$`) + instanceIDRegex = regexp.MustCompile(`^(.*-)(\d+)$`) errorBadInstanceIDFormat = func(instanceID string) error { return fmt.Errorf("unable to extract instance id from %q", instanceID) } - errorNoPreviousInstance = fmt.Errorf("impossible to find the instance preceding the target instance, because it is the first instance") errorMissingPreviousInstance = func(requiredInstanceID string) error { return fmt.Errorf("the instance %q has not been registered to the ring or has no tokens yet", requiredInstanceID) @@ -49,15 +44,13 @@ var ( ) type SpreadMinimizingTokenGenerator struct { - instanceID int - instance string - zoneID int - spreadMinimizingZones []string - canJoinEnabled bool - logger log.Logger + instanceID int + instancePrefix string + zoneID int + canJoinEnabled bool } -func NewSpreadMinimizingTokenGenerator(instance, zone string, spreadMinimizingZones []string, canJoinEnabled bool, logger log.Logger) (*SpreadMinimizingTokenGenerator, error) { +func NewSpreadMinimizingTokenGenerator(instance, zone string, spreadMinimizingZones []string, canJoinEnabled bool) (*SpreadMinimizingTokenGenerator, error) { if len(spreadMinimizingZones) <= 0 || len(spreadMinimizingZones) > maxZonesCount { return nil, errorZoneCountOutOfBound(len(spreadMinimizingZones)) } @@ -66,52 +59,35 @@ func NewSpreadMinimizingTokenGenerator(instance, zone string, spreadMinimizingZo if !slices.IsSorted(sortedZones) { sort.Strings(sortedZones) } - instanceID, err := parseInstanceID(instance) + zoneID, err := findZoneID(zone, sortedZones) if err != nil { return nil, err } - zoneID, err := findZoneID(zone, sortedZones) + + prefix, instanceID, err := parseInstanceID(instance) if err != nil { return nil, err } - tokenGenerator := &SpreadMinimizingTokenGenerator{ - instanceID: instanceID, - instance: instance, - zoneID: zoneID, - spreadMinimizingZones: sortedZones, - canJoinEnabled: canJoinEnabled, - logger: logger, - } - return tokenGenerator, nil + return NewSpreadMinimizingTokenGeneratorForInstanceAndZoneID(prefix, instanceID, zoneID, canJoinEnabled), nil } -func parseInstanceID(instanceID string) (int, error) { - parts := instanceIDRegex.FindStringSubmatch(instanceID) - if len(parts) != 3 { - return 0, errorBadInstanceIDFormat(instanceID) +func NewSpreadMinimizingTokenGeneratorForInstanceAndZoneID(instancePrefix string, instanceID, zoneID int, 
canJoinEnabled bool) *SpreadMinimizingTokenGenerator { + return &SpreadMinimizingTokenGenerator{ + instanceID: instanceID, + instancePrefix: instancePrefix, + zoneID: zoneID, + canJoinEnabled: canJoinEnabled, } - return strconv.Atoi(parts[2]) } -// previousInstance determines the string id of the instance preceding the given instance string id. -// If it is impossible to parse the given instanceID, or it is impossible to determine its predecessor -// because the passed instanceID has a bad format, or has no predecessor, an error is returned. -// For examples, my-instance-1 is preceded by instance my-instance-0, but my-instance-0 has no -// predecessor because its index is 0. -func previousInstance(instanceID string) (string, error) { +func parseInstanceID(instanceID string) (string, int, error) { parts := instanceIDRegex.FindStringSubmatch(instanceID) if len(parts) != 3 { - return "", errorBadInstanceIDFormat(instanceID) - } - id, err := strconv.Atoi(parts[2]) - if err != nil { - return "", err - } - if id == 0 { - return "", errorNoPreviousInstance + return "", 0, errorBadInstanceIDFormat(instanceID) } - return fmt.Sprintf("%s-%d", parts[1], id-1), nil + val, err := strconv.Atoi(parts[2]) + return parts[1], val, err } // findZoneID gets a zone name and a slice of sorted zones, @@ -193,7 +169,11 @@ func (t *SpreadMinimizingTokenGenerator) GenerateTokens(requestedTokensCount int used[v] = true } - allTokens := t.generateAllTokens() + allTokens, err := t.generateAllTokens() + if err != nil { + // we were unable to generate required tokens, so we panic. + panic(err) + } uniqueTokens := make(Tokens, 0, requestedTokensCount) // allTokens is a sorted slice of tokens for instance t.cfg.InstanceID in zone t.cfg.zone @@ -214,11 +194,14 @@ func (t *SpreadMinimizingTokenGenerator) GenerateTokens(requestedTokensCount int // placed in the ring that already contains instances with all the ids lower that t.instanceID // is optimal. // Calls to this method will always return the same set of tokens. -func (t *SpreadMinimizingTokenGenerator) generateAllTokens() Tokens { - tokensByInstanceID := t.generateTokensByInstanceID() +func (t *SpreadMinimizingTokenGenerator) generateAllTokens() (Tokens, error) { + tokensByInstanceID, err := t.generateTokensByInstanceID() + if err != nil { + return nil, err + } allTokens := tokensByInstanceID[t.instanceID] slices.Sort(allTokens) - return allTokens + return allTokens, nil } // generateTokensByInstanceID generates the optimal number of tokens (optimalTokenPerInstance), @@ -226,13 +209,13 @@ func (t *SpreadMinimizingTokenGenerator) generateAllTokens() Tokens { // (with id t.instanceID). Generated tokens are not sorted, but they are distributed in such a // way that registered ownership of all the instances is optimal. // Calls to this method will always return the same set of tokens. -func (t *SpreadMinimizingTokenGenerator) generateTokensByInstanceID() map[int]Tokens { +func (t *SpreadMinimizingTokenGenerator) generateTokensByInstanceID() (map[int]Tokens, error) { firstInstanceTokens := t.generateFirstInstanceTokens() tokensByInstanceID := make(map[int]Tokens, t.instanceID+1) tokensByInstanceID[0] = firstInstanceTokens if t.instanceID == 0 { - return tokensByInstanceID + return tokensByInstanceID, nil } // tokensQueues is a slice of priority queues. 
Slice indexes correspond @@ -272,10 +255,8 @@ func (t *SpreadMinimizingTokenGenerator) generateTokensByInstanceID() map[int]To optimalTokenOwnership := t.optimalTokenOwnership(optimalInstanceOwnership, currInstanceOwnership, uint32(optimalTokensPerInstance-addedTokens)) highestOwnershipInstance := instanceQueue.Peek() if highestOwnershipInstance == nil || highestOwnershipInstance.ownership <= float64(optimalTokenOwnership) { - level.Warn(t.logger).Log("msg", "it was impossible to add a token because the instance with the highest ownership cannot satisfy the request", "added tokens", addedTokens+1, "highest ownership", highestOwnershipInstance.ownership, "requested ownership", optimalTokenOwnership) - // if this happens, it means that we cannot accommodate other tokens, so we panic - err := fmt.Errorf("it was impossible to add %dth token for instance with id %d in zone %s because the instance with the highest ownership cannot satisfy the requested ownership %d", addedTokens+1, i, t.spreadMinimizingZones[t.zoneID], optimalTokenOwnership) - panic(err) + // if this happens, it means that we cannot accommodate other tokens + return nil, fmt.Errorf("it was impossible to add %dth token for instance with id %d in zone id %d because the instance with the highest ownership cannot satisfy the requested ownership %d", addedTokens+1, i, t.zoneID, optimalTokenOwnership) } tokensQueue := tokensQueues[highestOwnershipInstance.item.instanceID] highestOwnershipToken := tokensQueue.Peek() @@ -288,10 +269,8 @@ func (t *SpreadMinimizingTokenGenerator) generateTokensByInstanceID() map[int]To token := highestOwnershipToken.item newToken, err := t.calculateNewToken(token, optimalTokenOwnership) if err != nil { - level.Error(t.logger).Log("msg", "it was impossible to calculate a new token because an error occurred", "err", err) - // if this happens, it means that we cannot accommodate additional tokens, so we panic - err := fmt.Errorf("it was impossible to calculate the %dth token for instance with id %d in zone %s", addedTokens+1, i, t.spreadMinimizingZones[t.zoneID]) - panic(err) + // if this happens, it means that we cannot accommodate additional tokens + return nil, fmt.Errorf("it was impossible to calculate the %dth token for instance with id %d in zone id %d", addedTokens+1, i, t.zoneID) } tokens = append(tokens, newToken) // add the new token to currInstanceTokenQueue @@ -317,7 +296,7 @@ func (t *SpreadMinimizingTokenGenerator) generateTokensByInstanceID() map[int]To tokensByInstanceID[i] = tokens // if this is the last iteration we return, so we avoid to call additional heap.Pushs if i == t.instanceID { - return tokensByInstanceID + return tokensByInstanceID, nil } // If there were some ignored instances, we put them back on the queue. 
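As a side note on the `instanceIDRegex` change above (the first capture group now keeps the trailing hyphen), a small self-contained example of how the prefix and numeric ID are split and how the predecessor name is rebuilt; the instance name is hypothetical:

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// Same pattern as the updated instanceIDRegex: keeping the hyphen in the
// prefix lets the predecessor be rebuilt as simply prefix + (id - 1).
var instanceIDRegex = regexp.MustCompile(`^(.*-)(\d+)$`)

func main() {
	parts := instanceIDRegex.FindStringSubmatch("ingester-zone-a-5")
	id, _ := strconv.Atoi(parts[2])
	fmt.Printf("prefix=%q id=%d previous=%s%d\n", parts[1], id, parts[1], id-1)
	// prefix="ingester-zone-a-" id=5 previous=ingester-zone-a-4
}
```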
@@ -331,7 +310,7 @@ func (t *SpreadMinimizingTokenGenerator) generateTokensByInstanceID() map[int]To heap.Push(&instanceQueue, newRingInstanceOwnershipInfo(i, currInstanceOwnership)) } - return tokensByInstanceID + return tokensByInstanceID, nil } func (t *SpreadMinimizingTokenGenerator) CanJoin(instances map[string]InstanceDesc) error { @@ -339,13 +318,10 @@ func (t *SpreadMinimizingTokenGenerator) CanJoin(instances map[string]InstanceDe return nil } - prevInstance, err := previousInstance(t.instance) - if err != nil { - if errors.Is(err, errorNoPreviousInstance) { - return nil - } - return err + if t.instanceID == 0 { + return nil } + prevInstance := fmt.Sprintf("%s%d", t.instancePrefix, t.instanceID-1) instanceDesc, ok := instances[prevInstance] if ok && len(instanceDesc.Tokens) != 0 { return nil diff --git a/vendor/github.com/grafana/dskit/ring/tokens.go b/vendor/github.com/grafana/dskit/ring/tokens.go index cf4999ff5d21..7f0780639421 100644 --- a/vendor/github.com/grafana/dskit/ring/tokens.go +++ b/vendor/github.com/grafana/dskit/ring/tokens.go @@ -7,6 +7,8 @@ import ( "sort" ) +type Token uint32 + // Tokens is a simple list of tokens. type Tokens []uint32 diff --git a/vendor/github.com/grafana/dskit/server/PROXYPROTOCOL.md b/vendor/github.com/grafana/dskit/server/PROXYPROTOCOL.md new file mode 100644 index 000000000000..726bde758dc8 --- /dev/null +++ b/vendor/github.com/grafana/dskit/server/PROXYPROTOCOL.md @@ -0,0 +1,28 @@ +# PROXY protocol support + +> **Note:** enabling PROXY protocol support does not break existing setups (e.g. non-PROXY connections are still accepted), however it does add a small overhead to the connection handling. + +To enable PROXY protocol support, set `Config.ProxyProtocolEnabled` to `true` before initializing a `Server` in your application. This enables PROXY protocol for both HTTP and gRPC servers. + +```go +cfg := &Config{ + ProxyProtocolEnabled: true, + // ... +} + +server := NewServer(cfg) +// ... +``` + +PROXY protocol is supported by using [go-proxyproto](https://github.com/pires/go-proxyproto). +Both PROXY v1 and PROXY v2 are supported out of the box. + +When enabled, incoming connections are checked for the PROXY header, and if present, the connection information is updated to reflect the original source address. +Most commonly, you will use the source address via [Request.RemoteAddr](https://pkg.go.dev/net/http#Request.RemoteAddr). + +```go +server.HTTP.HandleFunc("/your-endpoint", func(w http.ResponseWriter, r *http.Request) { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + // ... +}) +``` diff --git a/vendor/github.com/grafana/dskit/server/fake_server.pb.go b/vendor/github.com/grafana/dskit/server/fake_server.pb.go index 75ee6b0a14e3..4bb2d5a1f390 100644 --- a/vendor/github.com/grafana/dskit/server/fake_server.pb.go +++ b/vendor/github.com/grafana/dskit/server/fake_server.pb.go @@ -29,6 +29,49 @@ var _ = math.Inf // proto package needs to be updated. 
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package +type ProxyProtoIPResponse struct { + IP string `protobuf:"bytes,1,opt,name=IP,proto3" json:"IP,omitempty"` +} + +func (m *ProxyProtoIPResponse) Reset() { *m = ProxyProtoIPResponse{} } +func (*ProxyProtoIPResponse) ProtoMessage() {} +func (*ProxyProtoIPResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_a932e7b7b9f5c118, []int{0} +} +func (m *ProxyProtoIPResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *ProxyProtoIPResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_ProxyProtoIPResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *ProxyProtoIPResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ProxyProtoIPResponse.Merge(m, src) +} +func (m *ProxyProtoIPResponse) XXX_Size() int { + return m.Size() +} +func (m *ProxyProtoIPResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ProxyProtoIPResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ProxyProtoIPResponse proto.InternalMessageInfo + +func (m *ProxyProtoIPResponse) GetIP() string { + if m != nil { + return m.IP + } + return "" +} + type FailWithHTTPErrorRequest struct { Code int32 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"` } @@ -36,7 +79,7 @@ type FailWithHTTPErrorRequest struct { func (m *FailWithHTTPErrorRequest) Reset() { *m = FailWithHTTPErrorRequest{} } func (*FailWithHTTPErrorRequest) ProtoMessage() {} func (*FailWithHTTPErrorRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_a932e7b7b9f5c118, []int{0} + return fileDescriptor_a932e7b7b9f5c118, []int{1} } func (m *FailWithHTTPErrorRequest) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -73,32 +116,61 @@ func (m *FailWithHTTPErrorRequest) GetCode() int32 { } func init() { + proto.RegisterType((*ProxyProtoIPResponse)(nil), "server.ProxyProtoIPResponse") proto.RegisterType((*FailWithHTTPErrorRequest)(nil), "server.FailWithHTTPErrorRequest") } func init() { proto.RegisterFile("fake_server.proto", fileDescriptor_a932e7b7b9f5c118) } var fileDescriptor_a932e7b7b9f5c118 = []byte{ - // 265 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4c, 0x4b, 0xcc, 0x4e, - 0x8d, 0x2f, 0x4e, 0x2d, 0x2a, 0x4b, 0x2d, 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, - 0xf0, 0xa4, 0xa4, 0xd3, 0xf3, 0xf3, 0xd3, 0x73, 0x52, 0xf5, 0xc1, 0xa2, 0x49, 0xa5, 0x69, 0xfa, - 0xa9, 0xb9, 0x05, 0x25, 0x95, 0x10, 0x45, 0x4a, 0x7a, 0x5c, 0x12, 0x6e, 0x89, 0x99, 0x39, 0xe1, - 0x99, 0x25, 0x19, 0x1e, 0x21, 0x21, 0x01, 0xae, 0x45, 0x45, 0xf9, 0x45, 0x41, 0xa9, 0x85, 0xa5, - 0xa9, 0xc5, 0x25, 0x42, 0x42, 0x5c, 0x2c, 0xce, 0xf9, 0x29, 0xa9, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, - 0xac, 0x41, 0x60, 0xb6, 0xd1, 0x6d, 0x26, 0x2e, 0x2e, 0xb7, 0xc4, 0xec, 0xd4, 0x60, 0xb0, 0xd9, - 0x42, 0xd6, 0x5c, 0xec, 0xc1, 0xa5, 0xc9, 0xc9, 0xa9, 0xa9, 0x29, 0x42, 0x62, 0x7a, 0x10, 0x7b, - 0xf4, 0x60, 0xf6, 0xe8, 0xb9, 0x82, 0xec, 0x91, 0xc2, 0x21, 0xae, 0xc4, 0x20, 0xe4, 0xc8, 0xc5, - 0x0b, 0xb3, 0x1b, 0x6c, 0x2f, 0x19, 0x46, 0xf8, 0x73, 0x09, 0x62, 0x38, 0x5f, 0x48, 0x41, 0x0f, - 0x1a, 0x0e, 0xb8, 0x7c, 0x86, 0xc7, 0x40, 0x4b, 0x2e, 0xd6, 0xe0, 0x9c, 0xd4, 0xd4, 0x02, 0xb2, - 0xbc, 0xc3, 0x1d, 0x5c, 0x52, 0x94, 0x9a, 0x98, 0x4b, 0xa6, 0x01, 0x06, 0x8c, 0x4e, 0x26, 0x17, - 0x1e, 0xca, 0x31, 0xdc, 0x78, 0x28, 0xc7, 0xf0, 0xe1, 
0xa1, 0x1c, 0x63, 0xc3, 0x23, 0x39, 0xc6, - 0x15, 0x8f, 0xe4, 0x18, 0x4f, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, - 0xc6, 0x17, 0x8f, 0xe4, 0x18, 0x3e, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 0xe1, 0xc2, 0x63, - 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0x92, 0xd8, 0xc0, 0x26, 0x19, 0x03, 0x02, 0x00, 0x00, 0xff, - 0xff, 0x43, 0x2b, 0x71, 0x6d, 0x04, 0x02, 0x00, 0x00, -} + // 330 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x91, 0xb1, 0x4e, 0x02, 0x41, + 0x10, 0x86, 0x77, 0x51, 0x30, 0xae, 0xd1, 0x84, 0x8d, 0x31, 0x04, 0xcd, 0x84, 0x5c, 0x61, 0xac, + 0x0e, 0xa3, 0x36, 0xc6, 0x4a, 0x09, 0xc4, 0xab, 0xdc, 0xdc, 0x91, 0x58, 0x9a, 0x03, 0x06, 0x24, + 0x1c, 0xec, 0xb9, 0x77, 0x67, 0xa4, 0xf3, 0x11, 0x7c, 0x0c, 0x3b, 0x5f, 0xc3, 0x92, 0x92, 0x52, + 0x96, 0xc6, 0x92, 0x47, 0x30, 0x2c, 0x12, 0x0b, 0xc5, 0xe2, 0xba, 0x9d, 0xc9, 0xe4, 0xff, 0xbf, + 0x7f, 0x7f, 0x96, 0x6f, 0xfb, 0x3d, 0xbc, 0x8b, 0x50, 0x3d, 0xa2, 0xb2, 0x43, 0x25, 0x63, 0xc9, + 0x73, 0x8b, 0xa9, 0xb8, 0xdf, 0x91, 0xb2, 0x13, 0x60, 0xd9, 0x6c, 0x1b, 0x49, 0xbb, 0x8c, 0xfd, + 0x30, 0x1e, 0x2e, 0x8e, 0xac, 0x43, 0xb6, 0x2b, 0x94, 0x7c, 0x1a, 0x8a, 0xf9, 0xe4, 0x08, 0x17, + 0xa3, 0x50, 0x0e, 0x22, 0xe4, 0x3b, 0x2c, 0xe3, 0x88, 0x02, 0x2d, 0xd1, 0xa3, 0x4d, 0x37, 0xe3, + 0x08, 0xcb, 0x66, 0x85, 0x9a, 0xdf, 0x0d, 0x6e, 0xbb, 0xf1, 0xfd, 0x75, 0xbd, 0x2e, 0xaa, 0x4a, + 0x49, 0xe5, 0xe2, 0x43, 0x82, 0x51, 0xcc, 0x39, 0x5b, 0xaf, 0xc8, 0x16, 0x9a, 0xeb, 0xac, 0x6b, + 0xde, 0x27, 0x6f, 0x6b, 0x8c, 0xd5, 0xfc, 0x1e, 0x7a, 0x86, 0x81, 0x5f, 0xb0, 0x0d, 0x2f, 0x69, + 0x36, 0x11, 0x5b, 0x7c, 0xcf, 0x5e, 0xf0, 0xd8, 0x4b, 0x1e, 0xbb, 0x3a, 0xe7, 0x29, 0xae, 0xd8, + 0x5b, 0x84, 0x5f, 0xb2, 0xed, 0xa5, 0xb7, 0xf1, 0x4d, 0x21, 0x71, 0xc3, 0xf2, 0xbf, 0xf0, 0x79, + 0xc9, 0xfe, 0xfe, 0xaf, 0x55, 0xc9, 0xfe, 0x11, 0x3c, 0x67, 0x59, 0x2f, 0x40, 0x0c, 0x53, 0xc5, + 0xd9, 0xf2, 0x62, 0x85, 0x7e, 0x3f, 0xa5, 0xc0, 0x31, 0xe5, 0x2e, 0x2b, 0xb8, 0x18, 0x27, 0x6a, + 0xf0, 0xd3, 0x5d, 0xc5, 0x0f, 0x02, 0x54, 0x8e, 0x58, 0xa9, 0x77, 0xb0, 0x4c, 0xfb, 0x57, 0xdf, + 0x16, 0xb9, 0x3a, 0x1b, 0x4d, 0x80, 0x8c, 0x27, 0x40, 0x66, 0x13, 0xa0, 0xcf, 0x1a, 0xe8, 0xab, + 0x06, 0xfa, 0xae, 0x81, 0x8e, 0x34, 0xd0, 0x0f, 0x0d, 0xf4, 0x53, 0x03, 0x99, 0x69, 0xa0, 0x2f, + 0x53, 0x20, 0xa3, 0x29, 0x90, 0xf1, 0x14, 0x48, 0x23, 0x67, 0x5c, 0x4e, 0xbf, 0x02, 0x00, 0x00, + 0xff, 0xff, 0xf3, 0x3d, 0xce, 0x89, 0x80, 0x02, 0x00, 0x00, +} + +func (this *ProxyProtoIPResponse) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + that1, ok := that.(*ProxyProtoIPResponse) + if !ok { + that2, ok := that.(ProxyProtoIPResponse) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if this.IP != that1.IP { + return false + } + return true +} func (this *FailWithHTTPErrorRequest) Equal(that interface{}) bool { if that == nil { return this == nil @@ -123,6 +195,16 @@ func (this *FailWithHTTPErrorRequest) Equal(that interface{}) bool { } return true } +func (this *ProxyProtoIPResponse) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 5) + s = append(s, "&server.ProxyProtoIPResponse{") + s = append(s, "IP: "+fmt.Sprintf("%#v", this.IP)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} func (this *FailWithHTTPErrorRequest) GoString() string { if this == nil { return "nil" @@ -159,6 +241,7 @@ type FakeServerClient interface { FailWithHTTPError(ctx 
context.Context, in *FailWithHTTPErrorRequest, opts ...grpc.CallOption) (*empty.Empty, error) Sleep(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) StreamSleep(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (FakeServer_StreamSleepClient, error) + ReturnProxyProtoCallerIP(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*ProxyProtoIPResponse, error) } type fakeServerClient struct { @@ -237,6 +320,15 @@ func (x *fakeServerStreamSleepClient) Recv() (*empty.Empty, error) { return m, nil } +func (c *fakeServerClient) ReturnProxyProtoCallerIP(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*ProxyProtoIPResponse, error) { + out := new(ProxyProtoIPResponse) + err := c.cc.Invoke(ctx, "/server.FakeServer/ReturnProxyProtoCallerIP", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // FakeServerServer is the server API for FakeServer service. type FakeServerServer interface { Succeed(context.Context, *empty.Empty) (*empty.Empty, error) @@ -244,6 +336,7 @@ type FakeServerServer interface { FailWithHTTPError(context.Context, *FailWithHTTPErrorRequest) (*empty.Empty, error) Sleep(context.Context, *empty.Empty) (*empty.Empty, error) StreamSleep(*empty.Empty, FakeServer_StreamSleepServer) error + ReturnProxyProtoCallerIP(context.Context, *empty.Empty) (*ProxyProtoIPResponse, error) } // UnimplementedFakeServerServer can be embedded to have forward compatible implementations. @@ -265,6 +358,9 @@ func (*UnimplementedFakeServerServer) Sleep(ctx context.Context, req *empty.Empt func (*UnimplementedFakeServerServer) StreamSleep(req *empty.Empty, srv FakeServer_StreamSleepServer) error { return status.Errorf(codes.Unimplemented, "method StreamSleep not implemented") } +func (*UnimplementedFakeServerServer) ReturnProxyProtoCallerIP(ctx context.Context, req *empty.Empty) (*ProxyProtoIPResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ReturnProxyProtoCallerIP not implemented") +} func RegisterFakeServerServer(s *grpc.Server, srv FakeServerServer) { s.RegisterService(&_FakeServer_serviceDesc, srv) @@ -363,6 +459,24 @@ func (x *fakeServerStreamSleepServer) Send(m *empty.Empty) error { return x.ServerStream.SendMsg(m) } +func _FakeServer_ReturnProxyProtoCallerIP_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(empty.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(FakeServerServer).ReturnProxyProtoCallerIP(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/server.FakeServer/ReturnProxyProtoCallerIP", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(FakeServerServer).ReturnProxyProtoCallerIP(ctx, req.(*empty.Empty)) + } + return interceptor(ctx, in, info, handler) +} + var _FakeServer_serviceDesc = grpc.ServiceDesc{ ServiceName: "server.FakeServer", HandlerType: (*FakeServerServer)(nil), @@ -383,6 +497,10 @@ var _FakeServer_serviceDesc = grpc.ServiceDesc{ MethodName: "Sleep", Handler: _FakeServer_Sleep_Handler, }, + { + MethodName: "ReturnProxyProtoCallerIP", + Handler: _FakeServer_ReturnProxyProtoCallerIP_Handler, + }, }, Streams: []grpc.StreamDesc{ { @@ -394,6 +512,36 @@ var _FakeServer_serviceDesc = grpc.ServiceDesc{ Metadata: "fake_server.proto", } +func (m *ProxyProtoIPResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = 
make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ProxyProtoIPResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *ProxyProtoIPResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.IP) > 0 { + i -= len(m.IP) + copy(dAtA[i:], m.IP) + i = encodeVarintFakeServer(dAtA, i, uint64(len(m.IP))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + func (m *FailWithHTTPErrorRequest) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) @@ -433,6 +581,19 @@ func encodeVarintFakeServer(dAtA []byte, offset int, v uint64) int { dAtA[offset] = uint8(v) return base } +func (m *ProxyProtoIPResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.IP) + if l > 0 { + n += 1 + l + sovFakeServer(uint64(l)) + } + return n +} + func (m *FailWithHTTPErrorRequest) Size() (n int) { if m == nil { return 0 @@ -451,6 +612,16 @@ func sovFakeServer(x uint64) (n int) { func sozFakeServer(x uint64) (n int) { return sovFakeServer(uint64((x << 1) ^ uint64((int64(x) >> 63)))) } +func (this *ProxyProtoIPResponse) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&ProxyProtoIPResponse{`, + `IP:` + fmt.Sprintf("%v", this.IP) + `,`, + `}`, + }, "") + return s +} func (this *FailWithHTTPErrorRequest) String() string { if this == nil { return "nil" @@ -469,6 +640,91 @@ func valueToStringFakeServer(v interface{}) string { pv := reflect.Indirect(rv).Interface() return fmt.Sprintf("*%v", pv) } +func (m *ProxyProtoIPResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowFakeServer + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ProxyProtoIPResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ProxyProtoIPResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field IP", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowFakeServer + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthFakeServer + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthFakeServer + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.IP = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipFakeServer(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthFakeServer + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthFakeServer + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func (m *FailWithHTTPErrorRequest) Unmarshal(dAtA []byte) error { l := 
len(dAtA)
	iNdEx := 0
diff --git a/vendor/github.com/grafana/dskit/server/fake_server.proto b/vendor/github.com/grafana/dskit/server/fake_server.proto
index 248a6f244bda..0c4780cda0d4 100644
--- a/vendor/github.com/grafana/dskit/server/fake_server.proto
+++ b/vendor/github.com/grafana/dskit/server/fake_server.proto
@@ -10,6 +10,11 @@ service FakeServer {
   rpc FailWithHTTPError(FailWithHTTPErrorRequest) returns (google.protobuf.Empty) {};
   rpc Sleep(google.protobuf.Empty) returns (google.protobuf.Empty) {};
   rpc StreamSleep(google.protobuf.Empty) returns (stream google.protobuf.Empty) {};
+  rpc ReturnProxyProtoCallerIP(google.protobuf.Empty) returns (ProxyProtoIPResponse) {};
+}
+
+message ProxyProtoIPResponse {
+  string IP = 1;
 }
 
 message FailWithHTTPErrorRequest {
diff --git a/vendor/github.com/grafana/dskit/server/metrics.go b/vendor/github.com/grafana/dskit/server/metrics.go
index aa1c3e53aef9..d6011525da3a 100644
--- a/vendor/github.com/grafana/dskit/server/metrics.go
+++ b/vendor/github.com/grafana/dskit/server/metrics.go
@@ -15,12 +15,13 @@ import (
 )
 
 type Metrics struct {
-	TCPConnections      *prometheus.GaugeVec
-	TCPConnectionsLimit *prometheus.GaugeVec
-	RequestDuration     *prometheus.HistogramVec
-	ReceivedMessageSize *prometheus.HistogramVec
-	SentMessageSize     *prometheus.HistogramVec
-	InflightRequests    *prometheus.GaugeVec
+	TCPConnections           *prometheus.GaugeVec
+	TCPConnectionsLimit      *prometheus.GaugeVec
+	RequestDuration          *prometheus.HistogramVec
+	PerTenantRequestDuration *prometheus.HistogramVec
+	ReceivedMessageSize      *prometheus.HistogramVec
+	SentMessageSize          *prometheus.HistogramVec
+	InflightRequests         *prometheus.GaugeVec
 }
 
 func NewServerMetrics(cfg Config) *Metrics {
@@ -46,6 +47,15 @@ func NewServerMetrics(cfg Config) *Metrics {
 			NativeHistogramMaxBucketNumber:  100,
 			NativeHistogramMinResetDuration: time.Hour,
 		}, []string{"method", "route", "status_code", "ws"}),
+		PerTenantRequestDuration: reg.NewHistogramVec(prometheus.HistogramOpts{
+			Namespace:                       cfg.MetricsNamespace,
+			Name:                            "per_tenant_request_duration_seconds",
+			Help:                            "Time (in seconds) spent serving HTTP requests for a particular tenant.",
+			Buckets:                         instrument.DefBuckets,
+			NativeHistogramBucketFactor:     cfg.MetricsNativeHistogramFactor,
+			NativeHistogramMaxBucketNumber:  100,
+			NativeHistogramMinResetDuration: time.Hour,
+		}, []string{"method", "route", "status_code", "ws", "tenant"}),
 		ReceivedMessageSize: reg.NewHistogramVec(prometheus.HistogramOpts{
 			Namespace: cfg.MetricsNamespace,
 			Name:      "request_message_bytes",
diff --git a/vendor/github.com/grafana/dskit/server/server.go b/vendor/github.com/grafana/dskit/server/server.go
index 6c2133a9bc24..a23eead3891e 100644
--- a/vendor/github.com/grafana/dskit/server/server.go
+++ b/vendor/github.com/grafana/dskit/server/server.go
@@ -17,21 +17,21 @@ import (
 	"strings"
 	"time"
 
-	_ "github.com/grafana/pyroscope-go/godeltaprof/http/pprof" // anonymous import to get godelatprof handlers registered
-
 	gokit_log "github.com/go-kit/log"
 	"github.com/go-kit/log/level"
 	"github.com/gorilla/mux"
+	_ "github.com/grafana/pyroscope-go/godeltaprof/http/pprof" // anonymous import to get godeltaprof handlers registered
 	otgrpc "github.com/opentracing-contrib/go-grpc"
 	"github.com/opentracing/opentracing-go"
+	"github.com/pires/go-proxyproto"
 	"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/common/config" "github.com/prometheus/exporter-toolkit/web" - "github.com/soheilhy/cmux" "golang.org/x/net/netutil" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/experimental" "google.golang.org/grpc/keepalive" "github.com/grafana/dskit/httpgrpc" @@ -80,14 +80,15 @@ type Config struct { // for details. A generally useful value is 1.1. MetricsNativeHistogramFactor float64 `yaml:"-"` - HTTPListenNetwork string `yaml:"http_listen_network"` - HTTPListenAddress string `yaml:"http_listen_address"` - HTTPListenPort int `yaml:"http_listen_port"` - HTTPConnLimit int `yaml:"http_listen_conn_limit"` - GRPCListenNetwork string `yaml:"grpc_listen_network"` - GRPCListenAddress string `yaml:"grpc_listen_address"` - GRPCListenPort int `yaml:"grpc_listen_port"` - GRPCConnLimit int `yaml:"grpc_listen_conn_limit"` + HTTPListenNetwork string `yaml:"http_listen_network"` + HTTPListenAddress string `yaml:"http_listen_address"` + HTTPListenPort int `yaml:"http_listen_port"` + HTTPConnLimit int `yaml:"http_listen_conn_limit"` + GRPCListenNetwork string `yaml:"grpc_listen_network"` + GRPCListenAddress string `yaml:"grpc_listen_address"` + GRPCListenPort int `yaml:"grpc_listen_port"` + GRPCConnLimit int `yaml:"grpc_listen_conn_limit"` + ProxyProtocolEnabled bool `yaml:"proxy_protocol_enabled"` CipherSuites string `yaml:"tls_cipher_suites"` MinVersion string `yaml:"tls_min_version"` @@ -100,6 +101,8 @@ type Config struct { ExcludeRequestInLog bool `yaml:"-"` DisableRequestSuccessLog bool `yaml:"-"` + PerTenantDurationInstrumentation middleware.PerTenantCallback `yaml:"-"` + ServerGracefulShutdownTimeout time.Duration `yaml:"graceful_shutdown_timeout"` HTTPServerReadTimeout time.Duration `yaml:"http_server_read_timeout"` HTTPServerReadHeaderTimeout time.Duration `yaml:"http_server_read_header_timeout"` @@ -114,7 +117,6 @@ type Config struct { HTTPMiddleware []middleware.Interface `yaml:"-"` Router *mux.Router `yaml:"-"` DoNotAddDefaultHTTPMiddleware bool `yaml:"-"` - RouteHTTPToGRPC bool `yaml:"-"` GRPCServerMaxRecvMsgSize int `yaml:"grpc_server_max_recv_msg_size"` GRPCServerMaxSendMsgSize int `yaml:"grpc_server_max_send_msg_size"` @@ -127,11 +129,14 @@ type Config struct { GRPCServerMinTimeBetweenPings time.Duration `yaml:"grpc_server_min_time_between_pings"` GRPCServerPingWithoutStreamAllowed bool `yaml:"grpc_server_ping_without_stream_allowed"` GRPCServerNumWorkers int `yaml:"grpc_server_num_workers"` + GRPCServerStatsTrackingEnabled bool `yaml:"grpc_server_stats_tracking_enabled"` + GRPCServerRecvBufferPoolsEnabled bool `yaml:"grpc_server_recv_buffer_pools_enabled"` LogFormat string `yaml:"log_format"` LogLevel log.Level `yaml:"log_level"` Log gokit_log.Logger `yaml:"-"` LogSourceIPs bool `yaml:"log_source_ips_enabled"` + LogSourceIPsFull bool `yaml:"log_source_ips_full"` LogSourceIPsHeader string `yaml:"log_source_ips_header"` LogSourceIPsRegex string `yaml:"log_source_ips_regex"` LogRequestHeaders bool `yaml:"log_request_headers"` @@ -191,16 +196,20 @@ func (cfg *Config) RegisterFlags(f *flag.FlagSet) { f.DurationVar(&cfg.GRPCServerTimeout, "server.grpc.keepalive.timeout", time.Second*20, "After having pinged for keepalive check, the duration after which an idle connection should be closed, Default: 20s") f.DurationVar(&cfg.GRPCServerMinTimeBetweenPings, "server.grpc.keepalive.min-time-between-pings", 5*time.Minute, "Minimum amount of time a client 
should wait before sending a keepalive ping. If client sends keepalive ping more often, server will send GOAWAY and close the connection.")
 	f.BoolVar(&cfg.GRPCServerPingWithoutStreamAllowed, "server.grpc.keepalive.ping-without-stream-allowed", false, "If true, server allows keepalive pings even when there are no active streams(RPCs). If false, and client sends ping when there are no active streams, server will send GOAWAY and close the connection.")
+	f.BoolVar(&cfg.GRPCServerStatsTrackingEnabled, "server.grpc.stats-tracking-enabled", true, "If true, the request_message_bytes, response_message_bytes, and inflight_requests metrics will be tracked. Enabling this option prevents the use of memory pools for parsing gRPC request bodies and may lead to more memory allocations.")
+	f.BoolVar(&cfg.GRPCServerRecvBufferPoolsEnabled, "server.grpc.recv-buffer-pools-enabled", false, "If true, gRPC's buffer pools will be used to handle incoming requests. Enabling this feature can reduce memory allocation, but also requires disabling GRPC server stats tracking by setting `server.grpc.stats-tracking-enabled=false`. This is an experimental gRPC feature, so it might be removed in a future version of the gRPC library.")
 	f.IntVar(&cfg.GRPCServerNumWorkers, "server.grpc.num-workers", 0, "If non-zero, configures the amount of GRPC server workers used to serve the requests.")
 	f.StringVar(&cfg.PathPrefix, "server.path-prefix", "", "Base path to serve all API routes from (e.g. /v1/)")
 	f.StringVar(&cfg.LogFormat, "log.format", log.LogfmtFormat, "Output log messages in the given format. Valid formats: [logfmt, json]")
 	cfg.LogLevel.RegisterFlags(f)
 	f.BoolVar(&cfg.LogSourceIPs, "server.log-source-ips-enabled", false, "Optionally log the source IPs.")
+	f.BoolVar(&cfg.LogSourceIPsFull, "server.log-source-ips-full", false, "Log all source IPs instead of only the originating one. Only used if server.log-source-ips-enabled is true")
 	f.StringVar(&cfg.LogSourceIPsHeader, "server.log-source-ips-header", "", "Header field storing the source IPs. Only used if server.log-source-ips-enabled is true. If not set the default Forwarded, X-Real-IP and X-Forwarded-For headers are used")
 	f.StringVar(&cfg.LogSourceIPsRegex, "server.log-source-ips-regex", "", "Regex for matching the source IPs. Only used if server.log-source-ips-enabled is true. If not set the default Forwarded, X-Real-IP and X-Forwarded-For headers are used")
 	f.BoolVar(&cfg.LogRequestHeaders, "server.log-request-headers", false, "Optionally log request headers.")
 	f.StringVar(&cfg.LogRequestExcludeHeadersList, "server.log-request-headers-exclude-list", "", "Comma separated list of headers to exclude from loggin. Only used if server.log-request-headers is true.")
 	f.BoolVar(&cfg.LogRequestAtInfoLevel, "server.log-request-at-info-level-enabled", false, "Optionally log requests at info level instead of debug level. Applies to request headers as well if server.log-request-headers is enabled.")
+	f.BoolVar(&cfg.ProxyProtocolEnabled, "server.proxy-protocol-enabled", false, "Enables PROXY protocol.")
 }
 
 func (cfg *Config) registererOrDefault() prometheus.Registerer {
@@ -220,13 +229,6 @@ type Server struct {
 	grpcListener net.Listener
 	httpListener net.Listener
 
-	// These fields are used to support grpc over the http server
-	// if RouteHTTPToGRPC is set.
the fields are kept here - // so they can be initialized in New() and started in Run() - grpchttpmux cmux.CMux - grpcOnHTTPListener net.Listener - GRPCOnHTTPServer *grpc.Server - HTTP *mux.Router HTTPServer *http.Server GRPC *grpc.Server @@ -278,15 +280,6 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) { httpListener = netutil.LimitListener(httpListener, cfg.HTTPConnLimit) } - var grpcOnHTTPListener net.Listener - var grpchttpmux cmux.CMux - if cfg.RouteHTTPToGRPC { - grpchttpmux = cmux.New(httpListener) - - httpListener = grpchttpmux.Match(cmux.HTTP1Fast("PATCH")) - grpcOnHTTPListener = grpchttpmux.Match(cmux.HTTP2()) - } - network = cfg.GRPCListenNetwork if network == "" { network = DefaultNetwork @@ -302,6 +295,11 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) { grpcListener = netutil.LimitListener(grpcListener, cfg.GRPCConnLimit) } + if cfg.ProxyProtocolEnabled { + httpListener = newProxyProtocolListener(httpListener, cfg.HTTPServerReadHeaderTimeout) + grpcListener = newProxyProtocolListener(grpcListener, cfg.HTTPServerReadHeaderTimeout) + } + cipherSuites, err := stringToCipherSuites(cfg.CipherSuites) if err != nil { return nil, err @@ -375,22 +373,29 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) { WithRequest: !cfg.ExcludeRequestInLog, DisableRequestSuccessLog: cfg.DisableRequestSuccessLog, } - var reportGRPCStatusesOptions []middleware.InstrumentationOption + var grpcInstrumentationOptions []middleware.InstrumentationOption if cfg.ReportGRPCCodesInInstrumentationLabel { - reportGRPCStatusesOptions = []middleware.InstrumentationOption{middleware.ReportGRPCStatusOption} + grpcInstrumentationOptions = append(grpcInstrumentationOptions, middleware.ReportGRPCStatusOption) + } + if cfg.PerTenantDurationInstrumentation != nil { + grpcInstrumentationOptions = append(grpcInstrumentationOptions, + middleware.WithPerTenantInstrumentation( + metrics.PerTenantRequestDuration, + cfg.PerTenantDurationInstrumentation, + )) } grpcMiddleware := []grpc.UnaryServerInterceptor{ serverLog.UnaryServerInterceptor, otgrpc.OpenTracingServerInterceptor(opentracing.GlobalTracer()), middleware.HTTPGRPCTracingInterceptor(router), // This must appear after the OpenTracingServerInterceptor. - middleware.UnaryServerInstrumentInterceptor(metrics.RequestDuration, reportGRPCStatusesOptions...), + middleware.UnaryServerInstrumentInterceptor(metrics.RequestDuration, grpcInstrumentationOptions...), } grpcMiddleware = append(grpcMiddleware, cfg.GRPCMiddleware...) grpcStreamMiddleware := []grpc.StreamServerInterceptor{ serverLog.StreamServerInterceptor, otgrpc.OpenTracingStreamServerInterceptor(opentracing.GlobalTracer()), - middleware.StreamServerInstrumentInterceptor(metrics.RequestDuration, reportGRPCStatusesOptions...), + middleware.StreamServerInstrumentInterceptor(metrics.RequestDuration, grpcInstrumentationOptions...), } grpcStreamMiddleware = append(grpcStreamMiddleware, cfg.GRPCStreamMiddleware...) 
@@ -423,13 +428,22 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) { grpcOptions = append(grpcOptions, grpc.InTapHandle(grpcServerLimit.TapHandle), grpc.StatsHandler(grpcServerLimit)) } - grpcOptions = append(grpcOptions, - grpc.StatsHandler(middleware.NewStatsHandler( - metrics.ReceivedMessageSize, - metrics.SentMessageSize, - metrics.InflightRequests, - )), - ) + if cfg.GRPCServerStatsTrackingEnabled { + grpcOptions = append(grpcOptions, + grpc.StatsHandler(middleware.NewStatsHandler( + metrics.ReceivedMessageSize, + metrics.SentMessageSize, + metrics.InflightRequests, + )), + ) + } + + if cfg.GRPCServerRecvBufferPoolsEnabled { + if cfg.GRPCServerStatsTrackingEnabled { + return nil, fmt.Errorf("grpc_server_stats_tracking_enabled must be set to false if grpc_server_recv_buffer_pools_enabled is true") + } + grpcOptions = append(grpcOptions, experimental.RecvBufferPool(grpc.NewSharedBufferPool())) + } grpcOptions = append(grpcOptions, cfg.GRPCOptions...) if grpcTLSConfig != nil { @@ -437,41 +451,10 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) { grpcOptions = append(grpcOptions, grpc.Creds(grpcCreds)) } grpcServer := grpc.NewServer(grpcOptions...) - grpcOnHTTPServer := grpc.NewServer(grpcOptions...) - sourceIPs, err := middleware.NewSourceIPs(cfg.LogSourceIPsHeader, cfg.LogSourceIPsRegex) + httpMiddleware, err := BuildHTTPMiddleware(cfg, router, metrics, logger) if err != nil { - return nil, fmt.Errorf("error setting up source IP extraction: %v", err) - } - logSourceIPs := sourceIPs - if !cfg.LogSourceIPs { - // We always include the source IPs for traces, - // but only want to log them in the middleware if that is enabled. - logSourceIPs = nil - } - - defaultLogMiddleware := middleware.NewLogMiddleware(logger, cfg.LogRequestHeaders, cfg.LogRequestAtInfoLevel, logSourceIPs, strings.Split(cfg.LogRequestExcludeHeadersList, ",")) - defaultLogMiddleware.DisableRequestSuccessLog = cfg.DisableRequestSuccessLog - - defaultHTTPMiddleware := []middleware.Interface{ - middleware.Tracer{ - RouteMatcher: router, - SourceIPs: sourceIPs, - }, - defaultLogMiddleware, - middleware.Instrument{ - RouteMatcher: router, - Duration: metrics.RequestDuration, - RequestBodySize: metrics.ReceivedMessageSize, - ResponseBodySize: metrics.SentMessageSize, - InflightRequests: metrics.InflightRequests, - }, - } - var httpMiddleware []middleware.Interface - if cfg.DoNotAddDefaultHTTPMiddleware { - httpMiddleware = cfg.HTTPMiddleware - } else { - httpMiddleware = append(defaultHTTPMiddleware, cfg.HTTPMiddleware...) 
+ return nil, fmt.Errorf("error building http middleware: %w", err) } httpServer := &http.Server{ @@ -491,20 +474,17 @@ func newServer(cfg Config, metrics *Metrics) (*Server, error) { } return &Server{ - cfg: cfg, - httpListener: httpListener, - grpcListener: grpcListener, - grpcOnHTTPListener: grpcOnHTTPListener, - handler: handler, - grpchttpmux: grpchttpmux, - - HTTP: router, - HTTPServer: httpServer, - GRPC: grpcServer, - GRPCOnHTTPServer: grpcOnHTTPServer, - Log: logger, - Registerer: cfg.registererOrDefault(), - Gatherer: gatherer, + cfg: cfg, + httpListener: httpListener, + grpcListener: grpcListener, + handler: handler, + + HTTP: router, + HTTPServer: httpServer, + GRPC: grpcServer, + Log: logger, + Registerer: cfg.registererOrDefault(), + Gatherer: gatherer, }, nil } @@ -521,6 +501,48 @@ func RegisterInstrumentationWithGatherer(router *mux.Router, gatherer prometheus router.PathPrefix("/debug/pprof").Handler(http.DefaultServeMux) } +func BuildHTTPMiddleware(cfg Config, router *mux.Router, metrics *Metrics, logger gokit_log.Logger) ([]middleware.Interface, error) { + sourceIPs, err := middleware.NewSourceIPs(cfg.LogSourceIPsHeader, cfg.LogSourceIPsRegex, cfg.LogSourceIPsFull) + if err != nil { + return nil, fmt.Errorf("error setting up source IP extraction: %w", err) + } + logSourceIPs := sourceIPs + if !cfg.LogSourceIPs { + // We always include the source IPs for traces, + // but only want to log them in the middleware if that is enabled. + logSourceIPs = nil + } + + defaultLogMiddleware := middleware.NewLogMiddleware(logger, cfg.LogRequestHeaders, cfg.LogRequestAtInfoLevel, logSourceIPs, strings.Split(cfg.LogRequestExcludeHeadersList, ",")) + defaultLogMiddleware.DisableRequestSuccessLog = cfg.DisableRequestSuccessLog + + defaultHTTPMiddleware := []middleware.Interface{ + middleware.RouteInjector{ + RouteMatcher: router, + }, + middleware.Tracer{ + SourceIPs: sourceIPs, + }, + defaultLogMiddleware, + middleware.Instrument{ + Duration: metrics.RequestDuration, + PerTenantDuration: metrics.PerTenantRequestDuration, + PerTenantCallback: cfg.PerTenantDurationInstrumentation, + RequestBodySize: metrics.ReceivedMessageSize, + ResponseBodySize: metrics.SentMessageSize, + InflightRequests: metrics.InflightRequests, + }, + } + var httpMiddleware []middleware.Interface + if cfg.DoNotAddDefaultHTTPMiddleware { + httpMiddleware = cfg.HTTPMiddleware + } else { + httpMiddleware = append(defaultHTTPMiddleware, cfg.HTTPMiddleware...) + } + + return httpMiddleware, nil +} + // Run the server; blocks until SIGTERM (if signal handling is enabled), an error is received, or Stop() is called. func (s *Server) Run() error { errChan := make(chan error, 1) @@ -563,18 +585,6 @@ func (s *Server) Run() error { handleGRPCError(err, errChan) }() - // grpchttpmux will only be set if grpchttpmux RouteHTTPToGRPC is set - if s.grpchttpmux != nil { - go func() { - err := s.grpchttpmux.Serve() - handleGRPCError(err, errChan) - }() - go func() { - err := s.GRPCOnHTTPServer.Serve(s.grpcOnHTTPListener) - handleGRPCError(err, errChan) - }() - } - return <-errChan } @@ -615,3 +625,13 @@ func (s *Server) Shutdown() { _ = s.HTTPServer.Shutdown(ctx) s.GRPC.GracefulStop() } + +func newProxyProtocolListener(httpListener net.Listener, readHeaderTimeout time.Duration) net.Listener { + // Wraps the listener with a proxy protocol listener. + // NOTE: go-proxyproto supports non-PROXY, PROXY v1 and PROXY v2 protocols via the same listener. + // Therefore, enabling this feature does not break existing setups. 
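+	// The readHeaderTimeout passed in here is cfg.HTTPServerReadHeaderTimeout
+	// for both the HTTP and the gRPC listener (see the ProxyProtocolEnabled
+	// branch in newServer above); it bounds how long each accepted connection
+	// may take to deliver its PROXY header.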
+ return &proxyproto.Listener{ + Listener: httpListener, + ReadHeaderTimeout: readHeaderTimeout, + } +} diff --git a/vendor/github.com/grafana/dskit/spanlogger/spanlogger.go b/vendor/github.com/grafana/dskit/spanlogger/spanlogger.go index 08653eda38ab..70c86d16d85d 100644 --- a/vendor/github.com/grafana/dskit/spanlogger/spanlogger.go +++ b/vendor/github.com/grafana/dskit/spanlogger/spanlogger.go @@ -158,7 +158,7 @@ func (s *SpanLogger) getLogger() log.Logger { traceID, ok := tracing.ExtractSampledTraceID(s.ctx) if ok { - logger = log.With(logger, "traceID", traceID) + logger = log.With(logger, "trace_id", traceID) } // If the value has been set by another goroutine, fetch that other value and discard the one we made. if !s.logger.CompareAndSwap(nil, &logger) { @@ -167,3 +167,17 @@ func (s *SpanLogger) getLogger() log.Logger { } return logger } + +// SetSpanAndLogTag sets a tag on the span used by this SpanLogger, and appends a key/value pair to the logger used for +// future log lines emitted by this SpanLogger. +// +// It is not safe to call this method from multiple goroutines simultaneously. +// It is safe to call this method at the same time as calling other SpanLogger methods, however, this may produce +// inconsistent results (eg. some log lines may be emitted with the provided key/value pair, and others may not). +func (s *SpanLogger) SetSpanAndLogTag(key string, value interface{}) { + s.Span.SetTag(key, value) + + logger := s.getLogger() + wrappedLogger := log.With(logger, key, value) + s.logger.Store(&wrappedLogger) +} diff --git a/vendor/github.com/grafana/dskit/user/grpc.go b/vendor/github.com/grafana/dskit/user/grpc.go index 201b835eeab7..fcfd3d7a91cd 100644 --- a/vendor/github.com/grafana/dskit/user/grpc.go +++ b/vendor/github.com/grafana/dskit/user/grpc.go @@ -13,13 +13,8 @@ import ( // ExtractFromGRPCRequest extracts the user ID from the request metadata and returns // the user ID and a context with the user ID injected. func ExtractFromGRPCRequest(ctx context.Context) (string, context.Context, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return "", ctx, ErrNoOrgID - } - - orgIDs, ok := md[lowerOrgIDHeaderName] - if !ok || len(orgIDs) != 1 { + orgIDs := metadata.ValueFromIncomingContext(ctx, lowerOrgIDHeaderName) + if len(orgIDs) != 1 { return "", ctx, ErrNoOrgID } diff --git a/vendor/github.com/grafana/gomemcache/memcache/memcache.go b/vendor/github.com/grafana/gomemcache/memcache/memcache.go index c5962d092e0f..c627cbdf9834 100644 --- a/vendor/github.com/grafana/gomemcache/memcache/memcache.go +++ b/vendor/github.com/grafana/gomemcache/memcache/memcache.go @@ -619,7 +619,7 @@ func (c *Client) GetMulti(keys []string, opts ...Option) (map[string]*Item, erro options := newOptions(opts...) 
 	var lk sync.Mutex
-	m := make(map[string]*Item)
+	m := make(map[string]*Item, len(keys))
 	addItemToMap := func(it *Item) {
 		lk.Lock()
 		defer lk.Unlock()
diff --git a/vendor/github.com/pires/go-proxyproto/.gitignore b/vendor/github.com/pires/go-proxyproto/.gitignore
new file mode 100644
index 000000000000..a2d2c3019769
--- /dev/null
+++ b/vendor/github.com/pires/go-proxyproto/.gitignore
@@ -0,0 +1,11 @@
+# Compiled Object files, Static and Dynamic libs (Shared Objects)
+*.o
+*.a
+*.so
+
+# Folders
+.idea
+bin
+pkg
+
+*.out
diff --git a/vendor/github.com/soheilhy/cmux/LICENSE b/vendor/github.com/pires/go-proxyproto/LICENSE
similarity index 99%
rename from vendor/github.com/soheilhy/cmux/LICENSE
rename to vendor/github.com/pires/go-proxyproto/LICENSE
index d64569567334..a65c05a62717 100644
--- a/vendor/github.com/soheilhy/cmux/LICENSE
+++ b/vendor/github.com/pires/go-proxyproto/LICENSE
@@ -1,4 +1,3 @@
-
                                  Apache License
                            Version 2.0, January 2004
                         http://www.apache.org/licenses/
@@ -179,7 +178,7 @@ APPENDIX: How to apply the Apache License to your work.
 
       To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
+      boilerplate notice, with the fields enclosed by brackets "{}"
       replaced with your own identifying information. (Don't include
       the brackets!)  The text should be enclosed in the appropriate
       comment syntax for the file format. We also recommend that a
@@ -187,7 +186,7 @@ same "printed page" as the copyright notice for easier
    identification within third-party archives.
 
-   Copyright [yyyy] [name of copyright owner]
+   Copyright 2016 Paulo Pires
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
diff --git a/vendor/github.com/pires/go-proxyproto/README.md b/vendor/github.com/pires/go-proxyproto/README.md
new file mode 100644
index 000000000000..982707cceef8
--- /dev/null
+++ b/vendor/github.com/pires/go-proxyproto/README.md
@@ -0,0 +1,162 @@
+# go-proxyproto
+
+[![Actions Status](https://github.com/pires/go-proxyproto/workflows/test/badge.svg)](https://github.com/pires/go-proxyproto/actions)
+[![Coverage Status](https://coveralls.io/repos/github/pires/go-proxyproto/badge.svg?branch=master)](https://coveralls.io/github/pires/go-proxyproto?branch=master)
+[![Go Report Card](https://goreportcard.com/badge/github.com/pires/go-proxyproto)](https://goreportcard.com/report/github.com/pires/go-proxyproto)
+[![](https://godoc.org/github.com/pires/go-proxyproto?status.svg)](https://pkg.go.dev/github.com/pires/go-proxyproto?tab=doc)
+
+
+A Go library implementation of the [PROXY protocol, versions 1 and 2](https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt),
+which provides, as per specification:
+> (...) a convenient way to safely transport connection
+> information such as a client's address across multiple layers of NAT or TCP
+> proxies. It is designed to require little changes to existing components and
+> to limit the performance impact caused by the processing of the transported
+> information.
+
+This library can be used in proxy clients, proxy servers, or both, wherever said protocol needs to be supported.
+Both protocol versions, 1 (text-based) and 2 (binary-based), are supported.
+
+## Installation
+
+```shell
+$ go get -u github.com/pires/go-proxyproto
+```
+
+## Usage
+
+### Client
+
+```go
+package main
+
+import (
+	"io"
+	"log"
+	"net"
+
+	proxyproto "github.com/pires/go-proxyproto"
+)
+
+func chkErr(err error) {
+	if err != nil {
+		log.Fatalf("Error: %s", err.Error())
+	}
+}
+
+func main() {
+	// Dial some proxy listener e.g. https://github.com/mailgun/proxyproto
+	target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:2319")
+	chkErr(err)
+
+	conn, err := net.DialTCP("tcp", nil, target)
+	chkErr(err)
+
+	defer conn.Close()
+
+	// Create a proxyprotocol header or use HeaderProxyFromAddrs() if you
+	// have two conns
+	header := &proxyproto.Header{
+		Version:           1,
+		Command:           proxyproto.PROXY,
+		TransportProtocol: proxyproto.TCPv4,
+		SourceAddr: &net.TCPAddr{
+			IP:   net.ParseIP("10.1.1.1"),
+			Port: 1000,
+		},
+		DestinationAddr: &net.TCPAddr{
+			IP:   net.ParseIP("20.2.2.2"),
+			Port: 2000,
+		},
+	}
+	// After the connection was created write the proxy headers first
+	_, err = header.WriteTo(conn)
+	chkErr(err)
+	// Then your data... e.g.:
+	_, err = io.WriteString(conn, "HELO")
+	chkErr(err)
+}
+```
+
+### Server
+
+```go
+package main
+
+import (
+	"log"
+	"net"
+
+	proxyproto "github.com/pires/go-proxyproto"
+)
+
+func main() {
+	// Create a listener
+	addr := "localhost:9876"
+	list, err := net.Listen("tcp", addr)
+	if err != nil {
+		log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error())
+	}
+
+	// Wrap listener in a proxyproto listener
+	proxyListener := &proxyproto.Listener{Listener: list}
+	defer proxyListener.Close()
+
+	// Wait for a connection and accept it
+	conn, err := proxyListener.Accept()
+	if err != nil {
+		log.Fatalf("couldn't accept connection: %q\n", err.Error())
+	}
+	defer conn.Close()
+
+	// Print connection details
+	if conn.LocalAddr() == nil {
+		log.Fatal("couldn't retrieve local address")
+	}
+	log.Printf("local address: %q", conn.LocalAddr().String())
+
+	if conn.RemoteAddr() == nil {
+		log.Fatal("couldn't retrieve remote address")
+	}
+	log.Printf("remote address: %q", conn.RemoteAddr().String())
+}
+```
+
+### HTTP Server
+
+```go
+package main
+
+import (
+	"net"
+	"net/http"
+	"time"
+
+	"github.com/pires/go-proxyproto"
+)
+
+func main() {
+	server := http.Server{
+		Addr: ":8080",
+	}
+
+	ln, err := net.Listen("tcp", server.Addr)
+	if err != nil {
+		panic(err)
+	}
+
+	proxyListener := &proxyproto.Listener{
+		Listener:          ln,
+		ReadHeaderTimeout: 10 * time.Second,
+	}
+	defer proxyListener.Close()
+
+	server.Serve(proxyListener)
+}
+```
+
+## Special notes
+
+### AWS
+
+AWS Network Load Balancer (NLB) does not push the PPV2 header until the client starts sending the data. This is a problem if your server speaks first, e.g. SMTP, FTP, SSH, etc.
+
+By default, NLB target group attribute `proxy_protocol_v2.client_to_server.header_placement` has the value `on_first_ack_with_payload`. You need to contact AWS support to change it to `on_first_ack` instead.
+
+Just to be clear, you need this fix only if your server is designed to speak first.
diff --git a/vendor/github.com/pires/go-proxyproto/addr_proto.go b/vendor/github.com/pires/go-proxyproto/addr_proto.go
new file mode 100644
index 000000000000..d254fc41317c
--- /dev/null
+++ b/vendor/github.com/pires/go-proxyproto/addr_proto.go
@@ -0,0 +1,62 @@
+package proxyproto
+
+// AddressFamilyAndProtocol represents address family and transport protocol.
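+//
+// The high nibble encodes the address family (0x1_ IPv4, 0x2_ IPv6, 0x3_ Unix)
+// and the low nibble the transport protocol (0x_1 stream, 0x_2 datagram); the
+// Is* helpers below mask against exactly these nibbles.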
+type AddressFamilyAndProtocol byte + +const ( + UNSPEC AddressFamilyAndProtocol = '\x00' + TCPv4 AddressFamilyAndProtocol = '\x11' + UDPv4 AddressFamilyAndProtocol = '\x12' + TCPv6 AddressFamilyAndProtocol = '\x21' + UDPv6 AddressFamilyAndProtocol = '\x22' + UnixStream AddressFamilyAndProtocol = '\x31' + UnixDatagram AddressFamilyAndProtocol = '\x32' +) + +// IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise. +func (ap AddressFamilyAndProtocol) IsIPv4() bool { + return ap&0xF0 == 0x10 +} + +// IsIPv6 returns true if the address family is IPv6 (AF_INET6), false otherwise. +func (ap AddressFamilyAndProtocol) IsIPv6() bool { + return ap&0xF0 == 0x20 +} + +// IsUnix returns true if the address family is UNIX (AF_UNIX), false otherwise. +func (ap AddressFamilyAndProtocol) IsUnix() bool { + return ap&0xF0 == 0x30 +} + +// IsStream returns true if the transport protocol is TCP or STREAM (SOCK_STREAM), false otherwise. +func (ap AddressFamilyAndProtocol) IsStream() bool { + return ap&0x0F == 0x01 +} + +// IsDatagram returns true if the transport protocol is UDP or DGRAM (SOCK_DGRAM), false otherwise. +func (ap AddressFamilyAndProtocol) IsDatagram() bool { + return ap&0x0F == 0x02 +} + +// IsUnspec returns true if the transport protocol or address family is unspecified, false otherwise. +func (ap AddressFamilyAndProtocol) IsUnspec() bool { + return (ap&0xF0 == 0x00) || (ap&0x0F == 0x00) +} + +func (ap AddressFamilyAndProtocol) toByte() byte { + if ap.IsIPv4() && ap.IsStream() { + return byte(TCPv4) + } else if ap.IsIPv4() && ap.IsDatagram() { + return byte(UDPv4) + } else if ap.IsIPv6() && ap.IsStream() { + return byte(TCPv6) + } else if ap.IsIPv6() && ap.IsDatagram() { + return byte(UDPv6) + } else if ap.IsUnix() && ap.IsStream() { + return byte(UnixStream) + } else if ap.IsUnix() && ap.IsDatagram() { + return byte(UnixDatagram) + } + + return byte(UNSPEC) +} diff --git a/vendor/github.com/pires/go-proxyproto/header.go b/vendor/github.com/pires/go-proxyproto/header.go new file mode 100644 index 000000000000..81ebeb387eb1 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/header.go @@ -0,0 +1,280 @@ +// Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification: +// https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt +package proxyproto + +import ( + "bufio" + "bytes" + "errors" + "io" + "net" + "time" +) + +var ( + // Protocol + SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'} + SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'} + + ErrCantReadVersion1Header = errors.New("proxyproto: can't read version 1 header") + ErrVersion1HeaderTooLong = errors.New("proxyproto: version 1 header must be 107 bytes or less") + ErrLineMustEndWithCrlf = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n") + ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command") + ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol") + ErrCantReadLength = errors.New("proxyproto: can't read length") + ErrCantResolveSourceUnixAddress = errors.New("proxyproto: can't resolve source Unix address") + ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address") + ErrNoProxyProtocol = errors.New("proxyproto: proxy protocol signature not present") + ErrUnknownProxyProtocolVersion = errors.New("proxyproto: 
unknown proxy protocol version") + ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command") + ErrUnsupportedAddressFamilyAndProtocol = errors.New("proxyproto: unsupported address family and protocol") + ErrInvalidLength = errors.New("proxyproto: invalid length") + ErrInvalidAddress = errors.New("proxyproto: invalid address") + ErrInvalidPortNumber = errors.New("proxyproto: invalid port number") + ErrSuperfluousProxyHeader = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one") +) + +// Header is the placeholder for proxy protocol header. +type Header struct { + Version byte + Command ProtocolVersionAndCommand + TransportProtocol AddressFamilyAndProtocol + SourceAddr net.Addr + DestinationAddr net.Addr + rawTLVs []byte +} + +// HeaderProxyFromAddrs creates a new PROXY header from a source and a +// destination address. If version is zero, the latest protocol version is +// used. +// +// The header is filled on a best-effort basis: if hints cannot be inferred +// from the provided addresses, the header will be left unspecified. +func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header { + if version < 1 || version > 2 { + version = 2 + } + h := &Header{ + Version: version, + Command: LOCAL, + TransportProtocol: UNSPEC, + } + switch sourceAddr := sourceAddr.(type) { + case *net.TCPAddr: + if _, ok := destAddr.(*net.TCPAddr); !ok { + break + } + if len(sourceAddr.IP.To4()) == net.IPv4len { + h.TransportProtocol = TCPv4 + } else if len(sourceAddr.IP) == net.IPv6len { + h.TransportProtocol = TCPv6 + } + case *net.UDPAddr: + if _, ok := destAddr.(*net.UDPAddr); !ok { + break + } + if len(sourceAddr.IP.To4()) == net.IPv4len { + h.TransportProtocol = UDPv4 + } else if len(sourceAddr.IP) == net.IPv6len { + h.TransportProtocol = UDPv6 + } + case *net.UnixAddr: + if _, ok := destAddr.(*net.UnixAddr); !ok { + break + } + switch sourceAddr.Net { + case "unix": + h.TransportProtocol = UnixStream + case "unixgram": + h.TransportProtocol = UnixDatagram + } + } + if h.TransportProtocol != UNSPEC { + h.Command = PROXY + h.SourceAddr = sourceAddr + h.DestinationAddr = destAddr + } + return h +} + +func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) { + if !header.TransportProtocol.IsStream() { + return nil, nil, false + } + sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) + destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) + return sourceAddr, destAddr, sourceOK && destOK +} + +func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) { + if !header.TransportProtocol.IsDatagram() { + return nil, nil, false + } + sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr) + destAddr, destOK := header.DestinationAddr.(*net.UDPAddr) + return sourceAddr, destAddr, sourceOK && destOK +} + +func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) { + if !header.TransportProtocol.IsUnix() { + return nil, nil, false + } + sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr) + destAddr, destOK := header.DestinationAddr.(*net.UnixAddr) + return sourceAddr, destAddr, sourceOK && destOK +} + +func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) { + if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { + return sourceAddr.IP, destAddr.IP, true + } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { + return sourceAddr.IP, destAddr.IP, true + } else { + return nil, nil, false + } +} + +func (header 
*Header) Ports() (sourcePort, destPort int, ok bool) { + if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { + return sourceAddr.Port, destAddr.Port, true + } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { + return sourceAddr.Port, destAddr.Port, true + } else { + return 0, 0, false + } +} + +// EqualTo returns true if headers are equivalent, false otherwise. +// Deprecated: use EqualsTo instead. This method will eventually be removed. +func (header *Header) EqualTo(otherHeader *Header) bool { + return header.EqualsTo(otherHeader) +} + +// EqualsTo returns true if headers are equivalent, false otherwise. +func (header *Header) EqualsTo(otherHeader *Header) bool { + if otherHeader == nil { + return false + } + // TLVs only exist for version 2 + if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { + return false + } + if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol { + return false + } + // Return early for header with LOCAL command, which contains no address information + if header.Command == LOCAL { + return true + } + return header.SourceAddr.String() == otherHeader.SourceAddr.String() && + header.DestinationAddr.String() == otherHeader.DestinationAddr.String() +} + +// WriteTo renders a proxy protocol header in a format and writes it to an io.Writer. +func (header *Header) WriteTo(w io.Writer) (int64, error) { + buf, err := header.Format() + if err != nil { + return 0, err + } + + return bytes.NewBuffer(buf).WriteTo(w) +} + +// Format renders a proxy protocol header in a format to write over the wire. +func (header *Header) Format() ([]byte, error) { + switch header.Version { + case 1: + return header.formatVersion1() + case 2: + return header.formatVersion2() + default: + return nil, ErrUnknownProxyProtocolVersion + } +} + +// TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol. +func (header *Header) TLVs() ([]TLV, error) { + return SplitTLVs(header.rawTLVs) +} + +// SetTLVs sets the TLVs stored in this header. This method replaces any +// previous TLV. +func (header *Header) SetTLVs(tlvs []TLV) error { + raw, err := JoinTLVs(tlvs) + if err != nil { + return err + } + header.rawTLVs = raw + return nil +} + +// Read identifies the proxy protocol version and reads the remaining of +// the header, accordingly. +// +// If proxy protocol header signature is not present, the reader buffer remains untouched +// and is safe for reading outside of this code. +// +// If proxy protocol header signature is present but an error is raised while processing +// the remaining header, assume the reader buffer to be in a corrupt state. +// Also, this operation will block until enough bytes are available for peeking. +func Read(reader *bufio.Reader) (*Header, error) { + // In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone. 
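+	// SIGV1 ("PROXY") is 5 bytes long and the binary SIGV2 signature is 12
+	// bytes long, hence the staged peeks of 1, 5, and 12 bytes below.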
+ b1, err := reader.Peek(1) + if err != nil { + if err == io.EOF { + return nil, ErrNoProxyProtocol + } + return nil, err + } + + if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) { + signature, err := reader.Peek(5) + if err != nil { + if err == io.EOF { + return nil, ErrNoProxyProtocol + } + return nil, err + } + if bytes.Equal(signature[:5], SIGV1) { + return parseVersion1(reader) + } + + signature, err = reader.Peek(12) + if err != nil { + if err == io.EOF { + return nil, ErrNoProxyProtocol + } + return nil, err + } + if bytes.Equal(signature[:12], SIGV2) { + return parseVersion2(reader) + } + } + + return nil, ErrNoProxyProtocol +} + +// ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed +// there's no proxy protocol header. +func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) { + type header struct { + h *Header + e error + } + read := make(chan *header, 1) + + go func() { + h := &header{} + h.h, h.e = Read(reader) + read <- h + }() + + timer := time.NewTimer(timeout) + select { + case result := <-read: + timer.Stop() + return result.h, result.e + case <-timer.C: + return nil, ErrNoProxyProtocol + } +} diff --git a/vendor/github.com/pires/go-proxyproto/policy.go b/vendor/github.com/pires/go-proxyproto/policy.go new file mode 100644 index 000000000000..6d505be4c803 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/policy.go @@ -0,0 +1,172 @@ +package proxyproto + +import ( + "fmt" + "net" + "strings" +) + +// PolicyFunc can be used to decide whether to trust the PROXY info from +// upstream. If set, the connecting address is passed in as an argument. +// +// See below for the different policies. +// +// In case an error is returned the connection is denied. +type PolicyFunc func(upstream net.Addr) (Policy, error) + +// Policy defines how a connection with a PROXY header address is treated. +type Policy int + +const ( + // USE address from PROXY header + USE Policy = iota + // IGNORE address from PROXY header, but accept connection + IGNORE + // REJECT connection when PROXY header is sent + // Note: even though the first read on the connection returns an error if + // a PROXY header is present, subsequent reads do not. It is the task of + // the code using the connection to handle that case properly. + REJECT + // REQUIRE connection to send PROXY header, reject if not present + // Note: even though the first read on the connection returns an error if + // a PROXY header is not present, subsequent reads do not. It is the task + // of the code using the connection to handle that case properly. + REQUIRE + // SKIP accepts a connection without requiring the PROXY header + // Note: an example usage can be found in the SkipProxyHeaderForCIDR + // function. + SKIP +) + +// SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a +// connection from a skipHeaderCIDR without requiring a PROXY header, e.g. +// Kubernetes pods local traffic. The def is a policy to use when an upstream +// address doesn't match the skipHeaderCIDR. 
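+//
+// A minimal usage sketch (the wrapped inner listener is assumed to exist):
+//
+//	_, cidr, _ := net.ParseCIDR("10.0.0.0/8")
+//	ln := &Listener{
+//		Listener: inner,
+//		Policy:   SkipProxyHeaderForCIDR(cidr, REQUIRE),
+//	}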
+func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc { + return func(upstream net.Addr) (Policy, error) { + ip, err := ipFromAddr(upstream) + if err != nil { + return def, err + } + + if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) { + return SKIP, nil + } + + return def, nil + } +} + +// WithPolicy adds given policy to a connection when passed as option to NewConn() +func WithPolicy(p Policy) func(*Conn) { + return func(c *Conn) { + c.ProxyHeaderPolicy = p + } +} + +// LaxWhiteListPolicy returns a PolicyFunc which decides whether the +// upstream ip is allowed to send a proxy header based on a list of allowed +// IP addresses and IP ranges. In case upstream IP is not in list the proxy +// header will be ignored. If one of the provided IP addresses or IP ranges +// is invalid it will return an error instead of a PolicyFunc. +func LaxWhiteListPolicy(allowed []string) (PolicyFunc, error) { + allowFrom, err := parse(allowed) + if err != nil { + return nil, err + } + + return whitelistPolicy(allowFrom, IGNORE), nil +} + +// MustLaxWhiteListPolicy returns a LaxWhiteListPolicy but will panic if one +// of the provided IP addresses or IP ranges is invalid. +func MustLaxWhiteListPolicy(allowed []string) PolicyFunc { + pfunc, err := LaxWhiteListPolicy(allowed) + if err != nil { + panic(err) + } + + return pfunc +} + +// StrictWhiteListPolicy returns a PolicyFunc which decides whether the +// upstream ip is allowed to send a proxy header based on a list of allowed +// IP addresses and IP ranges. In case upstream IP is not in list reading on +// the connection will be refused on the first read. Please note: subsequent +// reads do not error. It is the task of the code using the connection to +// handle that case properly. If one of the provided IP addresses or IP +// ranges is invalid it will return an error instead of a PolicyFunc. +func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { + allowFrom, err := parse(allowed) + if err != nil { + return nil, err + } + + return whitelistPolicy(allowFrom, REJECT), nil +} + +// MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic +// if one of the provided IP addresses or IP ranges is invalid. 
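+//
+// A usage sketch (the inner listener is assumed); the allow list accepts both
+// plain IPs and CIDR ranges:
+//
+//	ln := &Listener{
+//		Listener: inner,
+//		Policy:   MustStrictWhiteListPolicy([]string{"10.0.0.1", "192.168.0.0/16"}),
+//	}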
+func MustStrictWhiteListPolicy(allowed []string) PolicyFunc {
+	pfunc, err := StrictWhiteListPolicy(allowed)
+	if err != nil {
+		panic(err)
+	}
+
+	return pfunc
+}
+
+func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc {
+	return func(upstream net.Addr) (Policy, error) {
+		upstreamIP, err := ipFromAddr(upstream)
+		if err != nil {
+			// something is wrong with the source IP, better reject the connection
+			return REJECT, err
+		}
+
+		for _, allowFrom := range allowed {
+			if allowFrom(upstreamIP) {
+				return USE, nil
+			}
+		}
+
+		return def, nil
+	}
+}
+
+func parse(allowed []string) ([]func(net.IP) bool, error) {
+	a := make([]func(net.IP) bool, len(allowed))
+	for i, allowFrom := range allowed {
+		if strings.LastIndex(allowFrom, "/") > 0 {
+			_, ipRange, err := net.ParseCIDR(allowFrom)
+			if err != nil {
+				return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP range: %v", allowFrom, err)
+			}
+
+			a[i] = ipRange.Contains
+		} else {
+			allowed := net.ParseIP(allowFrom)
+			if allowed == nil {
+				return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP address", allowFrom)
+			}
+
+			a[i] = allowed.Equal
+		}
+	}
+
+	return a, nil
+}
+
+func ipFromAddr(upstream net.Addr) (net.IP, error) {
+	upstreamString, _, err := net.SplitHostPort(upstream.String())
+	if err != nil {
+		return nil, err
+	}
+
+	upstreamIP := net.ParseIP(upstreamString)
+	if nil == upstreamIP {
+		return nil, fmt.Errorf("proxyproto: invalid IP address")
+	}
+
+	return upstreamIP, nil
+}
diff --git a/vendor/github.com/pires/go-proxyproto/protocol.go b/vendor/github.com/pires/go-proxyproto/protocol.go
new file mode 100644
index 000000000000..4ce16a2765ba
--- /dev/null
+++ b/vendor/github.com/pires/go-proxyproto/protocol.go
@@ -0,0 +1,319 @@
+package proxyproto
+
+import (
+	"bufio"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// DefaultReadHeaderTimeout is how long header processing waits for header to
+// be read from the wire, if Listener.ReadHeaderTimeout is not set.
+// It's kept as a global variable so as to make it easier to find and override,
+// e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s"
+var DefaultReadHeaderTimeout = 10 * time.Second
+
+// Listener is used to wrap an underlying listener,
+// whose connections may be using the HAProxy Proxy Protocol.
+// If the connection is using the protocol, the RemoteAddr() will return
+// the correct client address. ReadHeaderTimeout will be applied to all
+// connections in order to prevent blocking operations. If no ReadHeaderTimeout
+// is set, DefaultReadHeaderTimeout (10s) will be used. This can be disabled
+// by setting the timeout to < 0.
+type Listener struct {
+	Listener          net.Listener
+	Policy            PolicyFunc
+	ValidateHeader    Validator
+	ReadHeaderTimeout time.Duration
+}
+
+// Conn is used to wrap an underlying connection which
+// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
+// return the address of the client instead of the proxy address. Each connection
+// will have its own readHeaderTimeout and readDeadline set by the Accept() call.
+type Conn struct {
+	readDeadline      atomic.Value // time.Time
+	once              sync.Once
+	readErr           error
+	conn              net.Conn
+	Validate          Validator
+	bufReader         *bufio.Reader
+	header            *Header
+	ProxyHeaderPolicy Policy
+	readHeaderTimeout time.Duration
+}
+
+// Validator receives a header and decides whether it is a valid one.
+// In case the header is not deemed valid it should return an error.
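The `Validator` type declared next plugs into `Listener.ValidateHeader`; here is a sketch under the assumption that only TCP (stream) transports should be accepted — the check itself is illustrative, not prescribed by this diff:

```go
package main

import (
	"errors"
	"log"
	"net"

	proxyproto "github.com/pires/go-proxyproto"
)

// requireTCP rejects PROXY headers that do not describe a stream
// (TCP) connection. Purely illustrative.
func requireTCP(h *proxyproto.Header) error {
	if !h.TransportProtocol.IsStream() {
		return errors.New("PROXY header must describe a TCP connection")
	}
	return nil
}

func main() {
	inner, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	ln := &proxyproto.Listener{
		Listener:       inner,
		ValidateHeader: requireTCP,
	}
	defer ln.Close()
}
```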
+type Validator func(*Header) error
+
+// ValidateHeader adds the given validator for proxy headers to a connection when passed as an option to NewConn()
+func ValidateHeader(v Validator) func(*Conn) {
+	return func(c *Conn) {
+		if v != nil {
+			c.Validate = v
+		}
+	}
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (p *Listener) Accept() (net.Conn, error) {
+	// Get the underlying connection
+	conn, err := p.Listener.Accept()
+	if err != nil {
+		return nil, err
+	}
+
+	proxyHeaderPolicy := USE
+	if p.Policy != nil {
+		proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
+		if err != nil {
+			// can't decide the policy, we can't accept the connection
+			conn.Close()
+			return nil, err
+		}
+		// Handle a connection as a regular one
+		if proxyHeaderPolicy == SKIP {
+			return conn, nil
+		}
+	}
+
+	newConn := NewConn(
+		conn,
+		WithPolicy(proxyHeaderPolicy),
+		ValidateHeader(p.ValidateHeader),
+	)
+
+	// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
+	if p.ReadHeaderTimeout == 0 {
+		p.ReadHeaderTimeout = DefaultReadHeaderTimeout
+	}
+
+	// Set the readHeaderTimeout of the new conn to the value of the listener
+	newConn.readHeaderTimeout = p.ReadHeaderTimeout
+
+	return newConn, nil
+}
+
+// Close closes the underlying listener.
+func (p *Listener) Close() error {
+	return p.Listener.Close()
+}
+
+// Addr returns the underlying listener's network address.
+func (p *Listener) Addr() net.Addr {
+	return p.Listener.Addr()
+}
+
+// NewConn is used to wrap a net.Conn that may be speaking
+// the proxy protocol into a proxyproto.Conn
+func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
+	pConn := &Conn{
+		bufReader: bufio.NewReader(conn),
+		conn:      conn,
+	}
+
+	for _, opt := range opts {
+		opt(pConn)
+	}
+
+	return pConn
+}
+
+// Read checks for the proxy protocol header when doing
+// the initial scan. If there is an error parsing the header,
+// it is returned and the socket is closed.
+func (p *Conn) Read(b []byte) (int, error) {
+	p.once.Do(func() {
+		p.readErr = p.readHeader()
+	})
+	if p.readErr != nil {
+		return 0, p.readErr
+	}
+
+	return p.bufReader.Read(b)
+}
+
+// Write wraps original conn.Write
+func (p *Conn) Write(b []byte) (int, error) {
+	return p.conn.Write(b)
+}
+
+// Close wraps original conn.Close
+func (p *Conn) Close() error {
+	return p.conn.Close()
+}
+
+// ProxyHeader returns the proxy protocol header, if any. If an error occurs
+// while reading the proxy header, nil is returned.
+func (p *Conn) ProxyHeader() *Header {
+	p.once.Do(func() { p.readErr = p.readHeader() })
+	return p.header
+}
+
+// LocalAddr returns the address of the server if the proxy
+// protocol is being used, otherwise just returns the address of
+// the socket server. In case an error happens on reading the
+// proxy header the original LocalAddr is returned, not the one
+// from the proxy header even if the proxy header itself is
+// syntactically correct.
+func (p *Conn) LocalAddr() net.Addr {
+	p.once.Do(func() { p.readErr = p.readHeader() })
+	if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
+		return p.conn.LocalAddr()
+	}
+
+	return p.header.DestinationAddr
+}
+
+// RemoteAddr returns the address of the client if the proxy
+// protocol is being used, otherwise just returns the address of
+// the socket peer. In case an error happens on reading the
+// proxy header the original RemoteAddr is returned, not the one
+// from the proxy header even if the proxy header itself is
+// syntactically correct.
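For code that accepts connections itself rather than through `Listener`, a minimal sketch of wrapping a single connection with `NewConn` and a `REQUIRE` policy; the package and function names are illustrative:

```go
package proxyexample

import (
	"log"
	"net"

	proxyproto "github.com/pires/go-proxyproto"
)

func handle(raw net.Conn) {
	// REQUIRE: the first Read fails with ErrNoProxyProtocol if the
	// peer never sent a header.
	pc := proxyproto.NewConn(raw, proxyproto.WithPolicy(proxyproto.REQUIRE))
	defer pc.Close()

	buf := make([]byte, 4096)
	if _, err := pc.Read(buf); err != nil {
		log.Println("read:", err)
		return
	}
	// With a header present, RemoteAddr reflects the original client.
	log.Println("client:", pc.RemoteAddr())
}
```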
+func (p *Conn) RemoteAddr() net.Addr {
+	p.once.Do(func() { p.readErr = p.readHeader() })
+	if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
+		return p.conn.RemoteAddr()
+	}
+
+	return p.header.SourceAddr
+}
+
+// Raw returns the underlying connection which can be cast to
+// a concrete type, allowing access to specialized functions.
+//
+// Use this ONLY if you know exactly what you are doing.
+func (p *Conn) Raw() net.Conn {
+	return p.conn
+}
+
+// TCPConn returns the underlying TCP connection,
+// allowing access to specialized functions.
+//
+// Use this ONLY if you know exactly what you are doing.
+func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) {
+	conn, ok = p.conn.(*net.TCPConn)
+	return
+}
+
+// UnixConn returns the underlying Unix socket connection,
+// allowing access to specialized functions.
+//
+// Use this ONLY if you know exactly what you are doing.
+func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) {
+	conn, ok = p.conn.(*net.UnixConn)
+	return
+}
+
+// UDPConn returns the underlying UDP connection,
+// allowing access to specialized functions.
+//
+// Use this ONLY if you know exactly what you are doing.
+func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {
+	conn, ok = p.conn.(*net.UDPConn)
+	return
+}
+
+// SetDeadline wraps original conn.SetDeadline
+func (p *Conn) SetDeadline(t time.Time) error {
+	p.readDeadline.Store(t)
+	return p.conn.SetDeadline(t)
+}
+
+// SetReadDeadline wraps original conn.SetReadDeadline
+func (p *Conn) SetReadDeadline(t time.Time) error {
+	// Set a local var that tells us the desired deadline. This is
+	// needed in order to reset the read deadline to the one that is
+	// desired by the user, rather than an empty deadline.
+	p.readDeadline.Store(t)
+	return p.conn.SetReadDeadline(t)
+}
+
+// SetWriteDeadline wraps original conn.SetWriteDeadline
+func (p *Conn) SetWriteDeadline(t time.Time) error {
+	return p.conn.SetWriteDeadline(t)
+}
+
+func (p *Conn) readHeader() error {
+	// If the connection's readHeaderTimeout is more than 0,
+	// push our deadline back to now plus the timeout. This should only
+	// run on the connection, as we don't want to override the previous
+	// read deadline the user may have used.
+	if p.readHeaderTimeout > 0 {
+		if err := p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout)); err != nil {
+			return err
+		}
+	}
+
+	header, err := Read(p.bufReader)
+
+	// If the connection's readHeaderTimeout is more than 0, undo the change to the
+	// deadline that we made above. Because we retain the readDeadline as part of our
+	// SetReadDeadline override, we know the user's desired deadline so we use that.
+	// Therefore, we check whether the error is a net.Error timeout and if it is, we
+	// decide the proxy proto does not exist and set the error accordingly.
+	if p.readHeaderTimeout > 0 {
+		t := p.readDeadline.Load()
+		if t == nil {
+			t = time.Time{}
+		}
+		if err := p.conn.SetReadDeadline(t.(time.Time)); err != nil {
+			return err
+		}
+		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+			err = ErrNoProxyProtocol
+		}
+	}
+
+	// For the purpose of this wrapper shamefully stolen from armon/go-proxyproto
+	// let's act as if there was no error when PROXY protocol is not present.
+	if err == ErrNoProxyProtocol {
+		// but not if it is required that the connection has one
+		if p.ProxyHeaderPolicy == REQUIRE {
+			return err
+		}
+
+		return nil
+	}
+
+	// proxy protocol header was found
+	if err == nil && header != nil {
+		switch p.ProxyHeaderPolicy {
+		case REJECT:
+			// this connection is not allowed to send one
+			return ErrSuperfluousProxyHeader
+		case USE, REQUIRE:
+			if p.Validate != nil {
+				err = p.Validate(header)
+				if err != nil {
+					return err
+				}
+			}
+
+			p.header = header
+		}
+	}
+
+	return err
+}
+
+// ReadFrom implements the io.ReaderFrom ReadFrom method
+func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
+	if rf, ok := p.conn.(io.ReaderFrom); ok {
+		return rf.ReadFrom(r)
+	}
+	return io.Copy(p.conn, r)
+}
+
+// WriteTo implements io.WriterTo
+func (p *Conn) WriteTo(w io.Writer) (int64, error) {
+	p.once.Do(func() { p.readErr = p.readHeader() })
+	if p.readErr != nil {
+		return 0, p.readErr
+	}
+	return p.bufReader.WriteTo(w)
+}
diff --git a/vendor/github.com/pires/go-proxyproto/tlv.go b/vendor/github.com/pires/go-proxyproto/tlv.go
new file mode 100644
index 000000000000..7cc2fb376ed7
--- /dev/null
+++ b/vendor/github.com/pires/go-proxyproto/tlv.go
@@ -0,0 +1,132 @@
+// Type-Length-Value splitting and parsing for proxy protocol V2
+// See spec https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt sections 2.2 to 2.7
+
+package proxyproto
+
+import (
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"math"
+)
+
+const (
+	// Section 2.2
+	PP2_TYPE_ALPN           PP2Type = 0x01
+	PP2_TYPE_AUTHORITY      PP2Type = 0x02
+	PP2_TYPE_CRC32C         PP2Type = 0x03
+	PP2_TYPE_NOOP           PP2Type = 0x04
+	PP2_TYPE_UNIQUE_ID      PP2Type = 0x05
+	PP2_TYPE_SSL            PP2Type = 0x20
+	PP2_SUBTYPE_SSL_VERSION PP2Type = 0x21
+	PP2_SUBTYPE_SSL_CN      PP2Type = 0x22
+	PP2_SUBTYPE_SSL_CIPHER  PP2Type = 0x23
+	PP2_SUBTYPE_SSL_SIG_ALG PP2Type = 0x24
+	PP2_SUBTYPE_SSL_KEY_ALG PP2Type = 0x25
+	PP2_TYPE_NETNS          PP2Type = 0x30
+
+	// Section 2.2.7, reserved types
+	PP2_TYPE_MIN_CUSTOM     PP2Type = 0xE0
+	PP2_TYPE_MAX_CUSTOM     PP2Type = 0xEF
+	PP2_TYPE_MIN_EXPERIMENT PP2Type = 0xF0
+	PP2_TYPE_MAX_EXPERIMENT PP2Type = 0xF7
+	PP2_TYPE_MIN_FUTURE     PP2Type = 0xF8
+	PP2_TYPE_MAX_FUTURE     PP2Type = 0xFF
+)
+
+var (
+	ErrTruncatedTLV    = errors.New("proxyproto: truncated TLV")
+	ErrMalformedTLV    = errors.New("proxyproto: malformed TLV Value")
+	ErrIncompatibleTLV = errors.New("proxyproto: incompatible TLV type")
+)
+
+// PP2Type is the proxy protocol v2 type
+type PP2Type byte
+
+// TLV is an uninterpreted Type-Length-Value for V2 protocol, see section 2.2
+type TLV struct {
+	Type  PP2Type
+	Value []byte
+}
+
+// SplitTLVs splits the Type-Length-Value vector, returns the vector or an error.
+func SplitTLVs(raw []byte) ([]TLV, error) {
+	var tlvs []TLV
+	for i := 0; i < len(raw); {
+		tlv := TLV{
+			Type: PP2Type(raw[i]),
+		}
+		if len(raw)-i <= 2 {
+			return nil, ErrTruncatedTLV
+		}
+		tlvLen := int(binary.BigEndian.Uint16(raw[i+1 : i+3])) // Max length = 65K
+		i += 3
+		if i+tlvLen > len(raw) {
+			return nil, ErrTruncatedTLV
+		}
+		// Ignore no-op padding
+		if tlv.Type != PP2_TYPE_NOOP {
+			tlv.Value = make([]byte, tlvLen)
+			copy(tlv.Value, raw[i:i+tlvLen])
+		}
+		i += tlvLen
+		tlvs = append(tlvs, tlv)
+	}
+	return tlvs, nil
+}
+
+// JoinTLVs joins multiple Type-Length-Value records.
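Before `JoinTLVs` itself, a short sketch of consuming a TLV vector with `SplitTLVs`; the hand-built authority record is an assumption for illustration:

```go
package main

import (
	"fmt"
	"log"

	proxyproto "github.com/pires/go-proxyproto"
)

func main() {
	// Hand-built TLV vector: type PP2_TYPE_AUTHORITY (0x02),
	// length 7 (big-endian), value "example".
	raw := append([]byte{0x02, 0x00, 0x07}, []byte("example")...)

	tlvs, err := proxyproto.SplitTLVs(raw)
	if err != nil {
		log.Fatal(err)
	}
	for _, tlv := range tlvs {
		if tlv.Type == proxyproto.PP2_TYPE_AUTHORITY {
			fmt.Println("authority:", string(tlv.Value))
		}
	}
}
```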
+func JoinTLVs(tlvs []TLV) ([]byte, error) {
+	var raw []byte
+	for _, tlv := range tlvs {
+		if len(tlv.Value) > math.MaxUint16 {
+			return nil, fmt.Errorf("proxyproto: cannot format TLV %v with length %d", tlv.Type, len(tlv.Value))
+		}
+		var length [2]byte
+		binary.BigEndian.PutUint16(length[:], uint16(len(tlv.Value)))
+		raw = append(raw, byte(tlv.Type))
+		raw = append(raw, length[:]...)
+		raw = append(raw, tlv.Value...)
+	}
+	return raw, nil
+}
+
+// Registered is true if the type is registered in the spec, see section 2.2
+func (p PP2Type) Registered() bool {
+	switch p {
+	case PP2_TYPE_ALPN,
+		PP2_TYPE_AUTHORITY,
+		PP2_TYPE_CRC32C,
+		PP2_TYPE_NOOP,
+		PP2_TYPE_UNIQUE_ID,
+		PP2_TYPE_SSL,
+		PP2_SUBTYPE_SSL_VERSION,
+		PP2_SUBTYPE_SSL_CN,
+		PP2_SUBTYPE_SSL_CIPHER,
+		PP2_SUBTYPE_SSL_SIG_ALG,
+		PP2_SUBTYPE_SSL_KEY_ALG,
+		PP2_TYPE_NETNS:
+		return true
+	}
+	return false
+}
+
+// App is true if the type is reserved for application specific data, see section 2.2.7
+func (p PP2Type) App() bool {
+	return p >= PP2_TYPE_MIN_CUSTOM && p <= PP2_TYPE_MAX_CUSTOM
+}
+
+// Experiment is true if the type is reserved for temporary experimental use by application developers, see section 2.2.7
+func (p PP2Type) Experiment() bool {
+	return p >= PP2_TYPE_MIN_EXPERIMENT && p <= PP2_TYPE_MAX_EXPERIMENT
+}
+
+// Future is true if the type is reserved for future use, see section 2.2.7
+func (p PP2Type) Future() bool {
+	return p >= PP2_TYPE_MIN_FUTURE
+}
+
+// Spec is true if the type is covered by the spec, see section 2.2 and 2.2.7
+func (p PP2Type) Spec() bool {
+	return p.Registered() || p.App() || p.Experiment() || p.Future()
+}
diff --git a/vendor/github.com/pires/go-proxyproto/v1.go b/vendor/github.com/pires/go-proxyproto/v1.go
new file mode 100644
index 000000000000..0d34ba5264e5
--- /dev/null
+++ b/vendor/github.com/pires/go-proxyproto/v1.go
@@ -0,0 +1,243 @@
+package proxyproto
+
+import (
+	"bufio"
+	"bytes"
+	"fmt"
+	"net"
+	"net/netip"
+	"strconv"
+	"strings"
+)
+
+const (
+	crlf      = "\r\n"
+	separator = " "
+)
+
+func initVersion1() *Header {
+	header := new(Header)
+	header.Version = 1
+	// Command doesn't exist in v1
+	header.Command = PROXY
+	return header
+}
+
+func parseVersion1(reader *bufio.Reader) (*Header, error) {
+	// The header cannot be more than 107 bytes long. Per spec:
+	//
+	//   (...)
+	//   - worst case (optional fields set to 0xff) :
+	//     "PROXY UNKNOWN ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n"
+	//     => 5 + 1 + 7 + 1 + 39 + 1 + 39 + 1 + 5 + 1 + 5 + 2 = 107 chars
+	//
+	//   So a 108-byte buffer is always enough to store all the line and a
+	//   trailing zero for string processing.
+	//
+	// It must also be CRLF terminated, as above. The header does not otherwise
+	// contain a CR or LF byte.
+	//
+	// ISSUE #69
+	// We can't use Peek here as it will block trying to fill the buffer, which
+	// will never happen if the header is TCP4 or TCP6 (max. 56 and 104 bytes
+	// respectively) and the server is expected to speak first.
+	//
+	// Similarly, we can't use ReadString or ReadBytes as these will keep reading
+	// until the delimiter is found; an abusive client could easily disrupt a
+	// server by sending a large amount of data that does not contain an LF byte.
+	// Another means of attack would be to start connections and simply not send
+	// data after the initial PROXY signature bytes, accumulating a large
+	// number of blocked goroutines on the server. ReadSlice will also block for
+	// a delimiter when the internal buffer does not fill up.
+ // + // A plain Read is also problematic since we risk reading past the end of the + // header without being able to easily put the excess bytes back into the reader's + // buffer (with the current implementation's design). + // + // So we use a ReadByte loop, which solves the overflow problem and avoids + // reading beyond the end of the header. However, we need one more trick to harden + // against partial header attacks (slow loris) - per spec: + // + // (..) The sender must always ensure that the header is sent at once, so that + // the transport layer maintains atomicity along the path to the receiver. The + // receiver may be tolerant to partial headers or may simply drop the connection + // when receiving a partial header. Recommendation is to be tolerant, but + // implementation constraints may not always easily permit this. + // + // We are subject to such implementation constraints. So we return an error if + // the header cannot be fully extracted with a single read of the underlying + // reader. + buf := make([]byte, 0, 107) + for { + b, err := reader.ReadByte() + if err != nil { + return nil, fmt.Errorf(ErrCantReadVersion1Header.Error()+": %v", err) + } + buf = append(buf, b) + if b == '\n' { + // End of header found + break + } + if len(buf) == 107 { + // No delimiter in first 107 bytes + return nil, ErrVersion1HeaderTooLong + } + if reader.Buffered() == 0 { + // Header was not buffered in a single read. Since we can't + // differentiate between genuine slow writers and DoS agents, + // we abort. On healthy networks, this should never happen. + return nil, ErrCantReadVersion1Header + } + } + + // Check for CR before LF. + if len(buf) < 2 || buf[len(buf)-2] != '\r' { + return nil, ErrLineMustEndWithCrlf + } + + // Check full signature. + tokens := strings.Split(string(buf[:len(buf)-2]), separator) + + // Expect at least 2 tokens: "PROXY" and the transport protocol. + if len(tokens) < 2 { + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // Read address family and protocol + var transportProtocol AddressFamilyAndProtocol + switch tokens[1] { + case "TCP4": + transportProtocol = TCPv4 + case "TCP6": + transportProtocol = TCPv6 + case "UNKNOWN": + transportProtocol = UNSPEC // doesn't exist in v1 but fits UNKNOWN + default: + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // Expect 6 tokens only when UNKNOWN is not present. + if transportProtocol != UNSPEC && len(tokens) < 6 { + return nil, ErrCantReadAddressFamilyAndProtocol + } + + // When a signature is found, allocate a v1 header with Command set to PROXY. + // Command doesn't exist in v1 but set it for other parts of this library + // to rely on it for determining connection details. + header := initVersion1() + + // Transport protocol has been processed already. 
+ header.TransportProtocol = transportProtocol + + // When UNKNOWN, set the command to LOCAL and return early + if header.TransportProtocol == UNSPEC { + header.Command = LOCAL + return header, nil + } + + // Otherwise, continue to read addresses and ports + sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2]) + if err != nil { + return nil, err + } + destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3]) + if err != nil { + return nil, err + } + sourcePort, err := parseV1PortNumber(tokens[4]) + if err != nil { + return nil, err + } + destPort, err := parseV1PortNumber(tokens[5]) + if err != nil { + return nil, err + } + header.SourceAddr = &net.TCPAddr{ + IP: sourceIP, + Port: sourcePort, + } + header.DestinationAddr = &net.TCPAddr{ + IP: destIP, + Port: destPort, + } + + return header, nil +} + +func (header *Header) formatVersion1() ([]byte, error) { + // As of version 1, only "TCP4" ( \x54 \x43 \x50 \x34 ) for TCP over IPv4, + // and "TCP6" ( \x54 \x43 \x50 \x36 ) for TCP over IPv6 are allowed. + var proto string + switch header.TransportProtocol { + case TCPv4: + proto = "TCP4" + case TCPv6: + proto = "TCP6" + default: + // Unknown connection (short form) + return []byte("PROXY UNKNOWN" + crlf), nil + } + + sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) + destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) + if !sourceOK || !destOK { + return nil, ErrInvalidAddress + } + + sourceIP, destIP := sourceAddr.IP, destAddr.IP + switch header.TransportProtocol { + case TCPv4: + sourceIP = sourceIP.To4() + destIP = destIP.To4() + case TCPv6: + sourceIP = sourceIP.To16() + destIP = destIP.To16() + } + if sourceIP == nil || destIP == nil { + return nil, ErrInvalidAddress + } + + buf := bytes.NewBuffer(make([]byte, 0, 108)) + buf.Write(SIGV1) + buf.WriteString(separator) + buf.WriteString(proto) + buf.WriteString(separator) + buf.WriteString(sourceIP.String()) + buf.WriteString(separator) + buf.WriteString(destIP.String()) + buf.WriteString(separator) + buf.WriteString(strconv.Itoa(sourceAddr.Port)) + buf.WriteString(separator) + buf.WriteString(strconv.Itoa(destAddr.Port)) + buf.WriteString(crlf) + + return buf.Bytes(), nil +} + +func parseV1PortNumber(portStr string) (int, error) { + port, err := strconv.Atoi(portStr) + if err != nil || port < 0 || port > 65535 { + return 0, ErrInvalidPortNumber + } + return port, nil +} + +func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (net.IP, error) { + addr, err := netip.ParseAddr(addrStr) + if err != nil { + return nil, ErrInvalidAddress + } + + switch protocol { + case TCPv4: + if addr.Is4() { + return net.IP(addr.AsSlice()), nil + } + case TCPv6: + if addr.Is6() || addr.Is4In6() { + return net.IP(addr.AsSlice()), nil + } + } + + return nil, ErrInvalidAddress +} diff --git a/vendor/github.com/pires/go-proxyproto/v2.go b/vendor/github.com/pires/go-proxyproto/v2.go new file mode 100644 index 000000000000..74bf3f077145 --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/v2.go @@ -0,0 +1,285 @@ +package proxyproto + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "io" + "net" +) + +var ( + lengthUnspec = uint16(0) + lengthV4 = uint16(12) + lengthV6 = uint16(36) + lengthUnix = uint16(216) + lengthUnspecBytes = func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthUnspec) + return a + }() + lengthV4Bytes = func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthV4) + return a + }() + lengthV6Bytes = 
func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthV6) + return a + }() + lengthUnixBytes = func() []byte { + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, lengthUnix) + return a + }() + errUint16Overflow = errors.New("proxyproto: uint16 overflow") +) + +type _ports struct { + SrcPort uint16 + DstPort uint16 +} + +type _addr4 struct { + Src [4]byte + Dst [4]byte + SrcPort uint16 + DstPort uint16 +} + +type _addr6 struct { + Src [16]byte + Dst [16]byte + _ports +} + +type _addrUnix struct { + Src [108]byte + Dst [108]byte +} + +func parseVersion2(reader *bufio.Reader) (header *Header, err error) { + // Skip first 12 bytes (signature) + for i := 0; i < 12; i++ { + if _, err = reader.ReadByte(); err != nil { + return nil, ErrCantReadProtocolVersionAndCommand + } + } + + header = new(Header) + header.Version = 2 + + // Read the 13th byte, protocol version and command + b13, err := reader.ReadByte() + if err != nil { + return nil, ErrCantReadProtocolVersionAndCommand + } + header.Command = ProtocolVersionAndCommand(b13) + if _, ok := supportedCommand[header.Command]; !ok { + return nil, ErrUnsupportedProtocolVersionAndCommand + } + + // Read the 14th byte, address family and protocol + b14, err := reader.ReadByte() + if err != nil { + return nil, ErrCantReadAddressFamilyAndProtocol + } + header.TransportProtocol = AddressFamilyAndProtocol(b14) + // UNSPEC is only supported when LOCAL is set. + if header.TransportProtocol == UNSPEC && header.Command != LOCAL { + return nil, ErrUnsupportedAddressFamilyAndProtocol + } + + // Make sure there are bytes available as specified in length + var length uint16 + if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil { + return nil, ErrCantReadLength + } + if !header.validateLength(length) { + return nil, ErrInvalidLength + } + + // Return early if the length is zero, which means that + // there's no address information and TLVs present for UNSPEC. + if length == 0 { + return header, nil + } + + if _, err := reader.Peek(int(length)); err != nil { + return nil, ErrInvalidLength + } + + // Length-limited reader for payload section + payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader) + + // Read addresses and ports for protocols other than UNSPEC. + // Ignore address information for UNSPEC, and skip straight to read TLVs, + // since the length is greater than zero. 
+ if header.TransportProtocol != UNSPEC { + if header.TransportProtocol.IsIPv4() { + var addr _addr4 + if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { + return nil, ErrInvalidAddress + } + header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) + header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) + } else if header.TransportProtocol.IsIPv6() { + var addr _addr6 + if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { + return nil, ErrInvalidAddress + } + header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) + header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) + } else if header.TransportProtocol.IsUnix() { + var addr _addrUnix + if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { + return nil, ErrInvalidAddress + } + + network := "unix" + if header.TransportProtocol.IsDatagram() { + network = "unixgram" + } + + header.SourceAddr = &net.UnixAddr{ + Net: network, + Name: parseUnixName(addr.Src[:]), + } + header.DestinationAddr = &net.UnixAddr{ + Net: network, + Name: parseUnixName(addr.Dst[:]), + } + } + } + + // Copy bytes for optional Type-Length-Value vector + header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice + if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF { + return nil, err + } + + return header, nil +} + +func (header *Header) formatVersion2() ([]byte, error) { + var buf bytes.Buffer + buf.Write(SIGV2) + buf.WriteByte(header.Command.toByte()) + buf.WriteByte(header.TransportProtocol.toByte()) + if header.TransportProtocol.IsUnspec() { + // For UNSPEC, write no addresses and ports but only TLVs if they are present + hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs)) + if err != nil { + return nil, err + } + buf.Write(hdrLen) + } else { + var addrSrc, addrDst []byte + if header.TransportProtocol.IsIPv4() { + hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs)) + if err != nil { + return nil, err + } + buf.Write(hdrLen) + sourceIP, destIP, _ := header.IPs() + addrSrc = sourceIP.To4() + addrDst = destIP.To4() + } else if header.TransportProtocol.IsIPv6() { + hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs)) + if err != nil { + return nil, err + } + buf.Write(hdrLen) + sourceIP, destIP, _ := header.IPs() + addrSrc = sourceIP.To16() + addrDst = destIP.To16() + } else if header.TransportProtocol.IsUnix() { + buf.Write(lengthUnixBytes) + sourceAddr, destAddr, ok := header.UnixAddrs() + if !ok { + return nil, ErrInvalidAddress + } + addrSrc = formatUnixName(sourceAddr.Name) + addrDst = formatUnixName(destAddr.Name) + } + + if addrSrc == nil || addrDst == nil { + return nil, ErrInvalidAddress + } + buf.Write(addrSrc) + buf.Write(addrDst) + + if sourcePort, destPort, ok := header.Ports(); ok { + portBytes := make([]byte, 2) + + binary.BigEndian.PutUint16(portBytes, uint16(sourcePort)) + buf.Write(portBytes) + + binary.BigEndian.PutUint16(portBytes, uint16(destPort)) + buf.Write(portBytes) + } + } + + if len(header.rawTLVs) > 0 { + buf.Write(header.rawTLVs) + } + + return buf.Bytes(), nil +} + +func (header *Header) validateLength(length uint16) bool { + if header.TransportProtocol.IsIPv4() { + return length >= lengthV4 + } else if header.TransportProtocol.IsIPv6() { + return length >= lengthV6 + } else if header.TransportProtocol.IsUnix() { + return length >= lengthUnix + } else if 
header.TransportProtocol.IsUnspec() { + return length >= lengthUnspec + } + return false +} + +// addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow. +func addTLVLen(cur []byte, tlvLen int) ([]byte, error) { + if tlvLen == 0 { + return cur, nil + } + curLen := binary.BigEndian.Uint16(cur) + newLen := int(curLen) + tlvLen + if newLen >= 1<<16 { + return nil, errUint16Overflow + } + a := make([]byte, 2) + binary.BigEndian.PutUint16(a, uint16(newLen)) + return a, nil +} + +func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr { + if transport.IsStream() { + return &net.TCPAddr{IP: ip, Port: int(port)} + } else if transport.IsDatagram() { + return &net.UDPAddr{IP: ip, Port: int(port)} + } else { + return nil + } +} + +func parseUnixName(b []byte) string { + i := bytes.IndexByte(b, 0) + if i < 0 { + return string(b) + } + return string(b[:i]) +} + +func formatUnixName(name string) []byte { + n := int(lengthUnix) / 2 + if len(name) >= n { + return []byte(name[:n]) + } + pad := make([]byte, n-len(name)) + return append([]byte(name), pad...) +} diff --git a/vendor/github.com/pires/go-proxyproto/version_cmd.go b/vendor/github.com/pires/go-proxyproto/version_cmd.go new file mode 100644 index 000000000000..59f20420882a --- /dev/null +++ b/vendor/github.com/pires/go-proxyproto/version_cmd.go @@ -0,0 +1,47 @@ +package proxyproto + +// ProtocolVersionAndCommand represents the command in proxy protocol v2. +// Command doesn't exist in v1 but it should be set since other parts of +// this library may rely on it for determining connection details. +type ProtocolVersionAndCommand byte + +const ( + // LOCAL represents the LOCAL command in v2 or UNKNOWN transport in v1, + // in which case no address information is expected. + LOCAL ProtocolVersionAndCommand = '\x20' + // PROXY represents the PROXY command in v2 or transport is not UNKNOWN in v1, + // in which case valid local/remote address and port information is expected. + PROXY ProtocolVersionAndCommand = '\x21' +) + +var supportedCommand = map[ProtocolVersionAndCommand]bool{ + LOCAL: true, + PROXY: true, +} + +// IsLocal returns true if the command in v2 is LOCAL or the transport in v1 is UNKNOWN, +// i.e. when no address information is expected, false otherwise. +func (pvc ProtocolVersionAndCommand) IsLocal() bool { + return LOCAL == pvc +} + +// IsProxy returns true if the command in v2 is PROXY or the transport in v1 is not UNKNOWN, +// i.e. when valid local/remote address and port information is expected, false otherwise. +func (pvc ProtocolVersionAndCommand) IsProxy() bool { + return PROXY == pvc +} + +// IsUnspec returns true if the command is unspecified, false otherwise. 
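A brief sketch of how the two commands are typically consumed once a header has been parsed; the package and function names are illustrative:

```go
package proxyexample

import (
	"fmt"

	proxyproto "github.com/pires/go-proxyproto"
)

// describe is illustrative: LOCAL commands carry no address
// information, while PROXY commands carry client/server addresses.
func describe(h *proxyproto.Header) string {
	if h == nil || h.Command.IsLocal() {
		return "local (no address information)"
	}
	return fmt.Sprintf("proxied %s -> %s", h.SourceAddr, h.DestinationAddr)
}
```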
+func (pvc ProtocolVersionAndCommand) IsUnspec() bool { + return !(pvc.IsLocal() || pvc.IsProxy()) +} + +func (pvc ProtocolVersionAndCommand) toByte() byte { + if pvc.IsLocal() { + return byte(LOCAL) + } else if pvc.IsProxy() { + return byte(PROXY) + } + + return byte(LOCAL) +} diff --git a/vendor/github.com/soheilhy/cmux/.gitignore b/vendor/github.com/soheilhy/cmux/.gitignore deleted file mode 100644 index daf913b1b347..000000000000 --- a/vendor/github.com/soheilhy/cmux/.gitignore +++ /dev/null @@ -1,24 +0,0 @@ -# Compiled Object files, Static and Dynamic libs (Shared Objects) -*.o -*.a -*.so - -# Folders -_obj -_test - -# Architecture specific extensions/prefixes -*.[568vq] -[568vq].out - -*.cgo1.go -*.cgo2.c -_cgo_defun.c -_cgo_gotypes.go -_cgo_export.* - -_testmain.go - -*.exe -*.test -*.prof diff --git a/vendor/github.com/soheilhy/cmux/.travis.yml b/vendor/github.com/soheilhy/cmux/.travis.yml deleted file mode 100644 index 4d78a519feb6..000000000000 --- a/vendor/github.com/soheilhy/cmux/.travis.yml +++ /dev/null @@ -1,29 +0,0 @@ -language: go - -go: - - 1.6 - - 1.7 - - 1.8 - - tip - -matrix: - allow_failures: - - go: tip - -gobuild_args: -race - -before_install: - - if [[ $TRAVIS_GO_VERSION == 1.6* ]]; then go get -u github.com/kisielk/errcheck; fi - - if [[ $TRAVIS_GO_VERSION == 1.6* ]]; then go get -u golang.org/x/lint/golint; fi - -before_script: - - '! gofmt -s -l . | read' - - echo $TRAVIS_GO_VERSION - - if [[ $TRAVIS_GO_VERSION == 1.6* ]]; then golint ./...; fi - - if [[ $TRAVIS_GO_VERSION == 1.6* ]]; then errcheck ./...; fi - - if [[ $TRAVIS_GO_VERSION == 1.6* ]]; then go tool vet .; fi - - if [[ $TRAVIS_GO_VERSION == 1.6* ]]; then go tool vet --shadow .; fi - -script: - - go test -bench . -v ./... - - go test -race -bench . -v ./... diff --git a/vendor/github.com/soheilhy/cmux/CONTRIBUTORS b/vendor/github.com/soheilhy/cmux/CONTRIBUTORS deleted file mode 100644 index 49878f228a12..000000000000 --- a/vendor/github.com/soheilhy/cmux/CONTRIBUTORS +++ /dev/null @@ -1,12 +0,0 @@ -# The list of people who have contributed code to the cmux repository. -# -# Auto-generated with: -# git log --oneline --pretty=format:'%an <%aE>' | sort -u -# -Andreas Jaekle -Dmitri Shuralyov -Ethan Mosbaugh -Soheil Hassas Yeganeh -Soheil Hassas Yeganeh -Tamir Duberstein -Tamir Duberstein diff --git a/vendor/github.com/soheilhy/cmux/README.md b/vendor/github.com/soheilhy/cmux/README.md deleted file mode 100644 index c4191b70b003..000000000000 --- a/vendor/github.com/soheilhy/cmux/README.md +++ /dev/null @@ -1,83 +0,0 @@ -# cmux: Connection Mux ![Travis Build Status](https://api.travis-ci.org/soheilhy/args.svg?branch=master "Travis Build Status") [![GoDoc](https://godoc.org/github.com/soheilhy/cmux?status.svg)](http://godoc.org/github.com/soheilhy/cmux) - -cmux is a generic Go library to multiplex connections based on -their payload. Using cmux, you can serve gRPC, SSH, HTTPS, HTTP, -Go RPC, and pretty much any other protocol on the same TCP listener. - -## How-To -Simply create your main listener, create a cmux for that listener, -and then match connections: -```go -// Create the main listener. -l, err := net.Listen("tcp", ":23456") -if err != nil { - log.Fatal(err) -} - -// Create a cmux. -m := cmux.New(l) - -// Match connections in order: -// First grpc, then HTTP, and otherwise Go RPC/TCP. 
-grpcL := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) -httpL := m.Match(cmux.HTTP1Fast()) -trpcL := m.Match(cmux.Any()) // Any means anything that is not yet matched. - -// Create your protocol servers. -grpcS := grpc.NewServer() -grpchello.RegisterGreeterServer(grpcS, &server{}) - -httpS := &http.Server{ - Handler: &helloHTTP1Handler{}, -} - -trpcS := rpc.NewServer() -trpcS.Register(&ExampleRPCRcvr{}) - -// Use the muxed listeners for your servers. -go grpcS.Serve(grpcL) -go httpS.Serve(httpL) -go trpcS.Accept(trpcL) - -// Start serving! -m.Serve() -``` - -Take a look at [other examples in the GoDoc](http://godoc.org/github.com/soheilhy/cmux/#pkg-examples). - -## Docs -* [GoDocs](https://godoc.org/github.com/soheilhy/cmux) - -## Performance -There is room for improvment but, since we are only matching -the very first bytes of a connection, the performance overheads on -long-lived connections (i.e., RPCs and pipelined HTTP streams) -is negligible. - -*TODO(soheil)*: Add benchmarks. - -## Limitations -* *TLS*: `net/http` uses a type assertion to identify TLS connections; since -cmux's lookahead-implementing connection wraps the underlying TLS connection, -this type assertion fails. -Because of that, you can serve HTTPS using cmux but `http.Request.TLS` -would not be set in your handlers. - -* *Different Protocols on The Same Connection*: `cmux` matches the connection -when it's accepted. For example, one connection can be either gRPC or REST, but -not both. That is, we assume that a client connection is either used for gRPC -or REST. - -* *Java gRPC Clients*: Java gRPC client blocks until it receives a SETTINGS -frame from the server. If you are using the Java client to connect to a cmux'ed -gRPC server please match with writers: -```go -grpcl := m.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc")) -``` - -# Copyright and License -Copyright 2016 The CMux Authors. All rights reserved. - -See [CONTRIBUTORS](https://github.com/soheilhy/cmux/blob/master/CONTRIBUTORS) -for the CMux Authors. Code is released under -[the Apache 2 license](https://github.com/soheilhy/cmux/blob/master/LICENSE). diff --git a/vendor/github.com/soheilhy/cmux/buffer.go b/vendor/github.com/soheilhy/cmux/buffer.go deleted file mode 100644 index f8cf30a1e66a..000000000000 --- a/vendor/github.com/soheilhy/cmux/buffer.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2016 The CMux Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -// implied. See the License for the specific language governing -// permissions and limitations under the License. - -package cmux - -import ( - "bytes" - "io" -) - -// bufferedReader is an optimized implementation of io.Reader that behaves like -// ``` -// io.MultiReader(bytes.NewReader(buffer.Bytes()), io.TeeReader(source, buffer)) -// ``` -// without allocating. 
-type bufferedReader struct { - source io.Reader - buffer bytes.Buffer - bufferRead int - bufferSize int - sniffing bool - lastErr error -} - -func (s *bufferedReader) Read(p []byte) (int, error) { - if s.bufferSize > s.bufferRead { - // If we have already read something from the buffer before, we return the - // same data and the last error if any. We need to immediately return, - // otherwise we may block for ever, if we try to be smart and call - // source.Read() seeking a little bit of more data. - bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize]) - s.bufferRead += bn - return bn, s.lastErr - } else if !s.sniffing && s.buffer.Cap() != 0 { - // We don't need the buffer anymore. - // Reset it to release the internal slice. - s.buffer = bytes.Buffer{} - } - - // If there is nothing more to return in the sniffed buffer, read from the - // source. - sn, sErr := s.source.Read(p) - if sn > 0 && s.sniffing { - s.lastErr = sErr - if wn, wErr := s.buffer.Write(p[:sn]); wErr != nil { - return wn, wErr - } - } - return sn, sErr -} - -func (s *bufferedReader) reset(snif bool) { - s.sniffing = snif - s.bufferRead = 0 - s.bufferSize = s.buffer.Len() -} diff --git a/vendor/github.com/soheilhy/cmux/cmux.go b/vendor/github.com/soheilhy/cmux/cmux.go deleted file mode 100644 index 5ba921e72dc0..000000000000 --- a/vendor/github.com/soheilhy/cmux/cmux.go +++ /dev/null @@ -1,307 +0,0 @@ -// Copyright 2016 The CMux Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -// implied. See the License for the specific language governing -// permissions and limitations under the License. - -package cmux - -import ( - "errors" - "fmt" - "io" - "net" - "sync" - "time" -) - -// Matcher matches a connection based on its content. -type Matcher func(io.Reader) bool - -// MatchWriter is a match that can also write response (say to do handshake). -type MatchWriter func(io.Writer, io.Reader) bool - -// ErrorHandler handles an error and returns whether -// the mux should continue serving the listener. -type ErrorHandler func(error) bool - -var _ net.Error = ErrNotMatched{} - -// ErrNotMatched is returned whenever a connection is not matched by any of -// the matchers registered in the multiplexer. -type ErrNotMatched struct { - c net.Conn -} - -func (e ErrNotMatched) Error() string { - return fmt.Sprintf("mux: connection %v not matched by an matcher", - e.c.RemoteAddr()) -} - -// Temporary implements the net.Error interface. -func (e ErrNotMatched) Temporary() bool { return true } - -// Timeout implements the net.Error interface. -func (e ErrNotMatched) Timeout() bool { return false } - -type errListenerClosed string - -func (e errListenerClosed) Error() string { return string(e) } -func (e errListenerClosed) Temporary() bool { return false } -func (e errListenerClosed) Timeout() bool { return false } - -// ErrListenerClosed is returned from muxListener.Accept when the underlying -// listener is closed. -var ErrListenerClosed = errListenerClosed("mux: listener closed") - -// ErrServerClosed is returned from muxListener.Accept when mux server is closed. 
-var ErrServerClosed = errors.New("mux: server closed") - -// for readability of readTimeout -var noTimeout time.Duration - -// New instantiates a new connection multiplexer. -func New(l net.Listener) CMux { - return &cMux{ - root: l, - bufLen: 1024, - errh: func(_ error) bool { return true }, - donec: make(chan struct{}), - readTimeout: noTimeout, - } -} - -// CMux is a multiplexer for network connections. -type CMux interface { - // Match returns a net.Listener that sees (i.e., accepts) only - // the connections matched by at least one of the matcher. - // - // The order used to call Match determines the priority of matchers. - Match(...Matcher) net.Listener - // MatchWithWriters returns a net.Listener that accepts only the - // connections that matched by at least of the matcher writers. - // - // Prefer Matchers over MatchWriters, since the latter can write on the - // connection before the actual handler. - // - // The order used to call Match determines the priority of matchers. - MatchWithWriters(...MatchWriter) net.Listener - // Serve starts multiplexing the listener. Serve blocks and perhaps - // should be invoked concurrently within a go routine. - Serve() error - // Closes cmux server and stops accepting any connections on listener - Close() - // HandleError registers an error handler that handles listener errors. - HandleError(ErrorHandler) - // sets a timeout for the read of matchers - SetReadTimeout(time.Duration) -} - -type matchersListener struct { - ss []MatchWriter - l muxListener -} - -type cMux struct { - root net.Listener - bufLen int - errh ErrorHandler - sls []matchersListener - readTimeout time.Duration - donec chan struct{} - mu sync.Mutex -} - -func matchersToMatchWriters(matchers []Matcher) []MatchWriter { - mws := make([]MatchWriter, 0, len(matchers)) - for _, m := range matchers { - cm := m - mws = append(mws, func(w io.Writer, r io.Reader) bool { - return cm(r) - }) - } - return mws -} - -func (m *cMux) Match(matchers ...Matcher) net.Listener { - mws := matchersToMatchWriters(matchers) - return m.MatchWithWriters(mws...) -} - -func (m *cMux) MatchWithWriters(matchers ...MatchWriter) net.Listener { - ml := muxListener{ - Listener: m.root, - connc: make(chan net.Conn, m.bufLen), - donec: make(chan struct{}), - } - m.sls = append(m.sls, matchersListener{ss: matchers, l: ml}) - return ml -} - -func (m *cMux) SetReadTimeout(t time.Duration) { - m.readTimeout = t -} - -func (m *cMux) Serve() error { - var wg sync.WaitGroup - - defer func() { - m.closeDoneChans() - wg.Wait() - - for _, sl := range m.sls { - close(sl.l.connc) - // Drain the connections enqueued for the listener. 
- for c := range sl.l.connc { - _ = c.Close() - } - } - }() - - for { - c, err := m.root.Accept() - if err != nil { - if !m.handleErr(err) { - return err - } - continue - } - - wg.Add(1) - go m.serve(c, m.donec, &wg) - } -} - -func (m *cMux) serve(c net.Conn, donec <-chan struct{}, wg *sync.WaitGroup) { - defer wg.Done() - - muc := newMuxConn(c) - if m.readTimeout > noTimeout { - _ = c.SetReadDeadline(time.Now().Add(m.readTimeout)) - } - for _, sl := range m.sls { - for _, s := range sl.ss { - matched := s(muc.Conn, muc.startSniffing()) - if matched { - muc.doneSniffing() - if m.readTimeout > noTimeout { - _ = c.SetReadDeadline(time.Time{}) - } - select { - case sl.l.connc <- muc: - case <-donec: - _ = c.Close() - } - return - } - } - } - - _ = c.Close() - err := ErrNotMatched{c: c} - if !m.handleErr(err) { - _ = m.root.Close() - } -} - -func (m *cMux) Close() { - m.closeDoneChans() -} - -func (m *cMux) closeDoneChans() { - m.mu.Lock() - defer m.mu.Unlock() - - select { - case <-m.donec: - // Already closed. Don't close again - default: - close(m.donec) - } - for _, sl := range m.sls { - select { - case <-sl.l.donec: - // Already closed. Don't close again - default: - close(sl.l.donec) - } - } -} - -func (m *cMux) HandleError(h ErrorHandler) { - m.errh = h -} - -func (m *cMux) handleErr(err error) bool { - if !m.errh(err) { - return false - } - - if ne, ok := err.(net.Error); ok { - return ne.Temporary() - } - - return false -} - -type muxListener struct { - net.Listener - connc chan net.Conn - donec chan struct{} -} - -func (l muxListener) Accept() (net.Conn, error) { - select { - case c, ok := <-l.connc: - if !ok { - return nil, ErrListenerClosed - } - return c, nil - case <-l.donec: - return nil, ErrServerClosed - } -} - -// MuxConn wraps a net.Conn and provides transparent sniffing of connection data. -type MuxConn struct { - net.Conn - buf bufferedReader -} - -func newMuxConn(c net.Conn) *MuxConn { - return &MuxConn{ - Conn: c, - buf: bufferedReader{source: c}, - } -} - -// From the io.Reader documentation: -// -// When Read encounters an error or end-of-file condition after -// successfully reading n > 0 bytes, it returns the number of -// bytes read. It may return the (non-nil) error from the same call -// or return the error (and n == 0) from a subsequent call. -// An instance of this general case is that a Reader returning -// a non-zero number of bytes at the end of the input stream may -// return either err == EOF or err == nil. The next Read should -// return 0, EOF. -func (m *MuxConn) Read(p []byte) (int, error) { - return m.buf.Read(p) -} - -func (m *MuxConn) startSniffing() io.Reader { - m.buf.reset(true) - return &m.buf -} - -func (m *MuxConn) doneSniffing() { - m.buf.reset(false) -} diff --git a/vendor/github.com/soheilhy/cmux/doc.go b/vendor/github.com/soheilhy/cmux/doc.go deleted file mode 100644 index aaa8f3158998..000000000000 --- a/vendor/github.com/soheilhy/cmux/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2016 The CMux Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -// implied. 
See the License for the specific language governing -// permissions and limitations under the License. - -// Package cmux is a library to multiplex network connections based on -// their payload. Using cmux, you can serve different protocols from the -// same listener. -package cmux diff --git a/vendor/github.com/soheilhy/cmux/matchers.go b/vendor/github.com/soheilhy/cmux/matchers.go deleted file mode 100644 index 878ae98cc3cc..000000000000 --- a/vendor/github.com/soheilhy/cmux/matchers.go +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright 2016 The CMux Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -// implied. See the License for the specific language governing -// permissions and limitations under the License. - -package cmux - -import ( - "bufio" - "crypto/tls" - "io" - "io/ioutil" - "net/http" - "strings" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" -) - -// Any is a Matcher that matches any connection. -func Any() Matcher { - return func(r io.Reader) bool { return true } -} - -// PrefixMatcher returns a matcher that matches a connection if it -// starts with any of the strings in strs. -func PrefixMatcher(strs ...string) Matcher { - pt := newPatriciaTreeString(strs...) - return pt.matchPrefix -} - -func prefixByteMatcher(list ...[]byte) Matcher { - pt := newPatriciaTree(list...) - return pt.matchPrefix -} - -var defaultHTTPMethods = []string{ - "OPTIONS", - "GET", - "HEAD", - "POST", - "PUT", - "DELETE", - "TRACE", - "CONNECT", -} - -// HTTP1Fast only matches the methods in the HTTP request. -// -// This matcher is very optimistic: if it returns true, it does not mean that -// the request is a valid HTTP response. If you want a correct but slower HTTP1 -// matcher, use HTTP1 instead. -func HTTP1Fast(extMethods ...string) Matcher { - return PrefixMatcher(append(defaultHTTPMethods, extMethods...)...) -} - -// TLS matches HTTPS requests. -// -// By default, any TLS handshake packet is matched. An optional whitelist -// of versions can be passed in to restrict the matcher, for example: -// TLS(tls.VersionTLS11, tls.VersionTLS12) -func TLS(versions ...int) Matcher { - if len(versions) == 0 { - versions = []int{ - tls.VersionSSL30, - tls.VersionTLS10, - tls.VersionTLS11, - tls.VersionTLS12, - } - } - prefixes := [][]byte{} - for _, v := range versions { - prefixes = append(prefixes, []byte{22, byte(v >> 8 & 0xff), byte(v & 0xff)}) - } - return prefixByteMatcher(prefixes...) -} - -const maxHTTPRead = 4096 - -// HTTP1 parses the first line or upto 4096 bytes of the request to see if -// the conection contains an HTTP request. -func HTTP1() Matcher { - return func(r io.Reader) bool { - br := bufio.NewReader(&io.LimitedReader{R: r, N: maxHTTPRead}) - l, part, err := br.ReadLine() - if err != nil || part { - return false - } - - _, _, proto, ok := parseRequestLine(string(l)) - if !ok { - return false - } - - v, _, ok := http.ParseHTTPVersion(proto) - return ok && v == 1 - } -} - -// grabbed from net/http. 
-func parseRequestLine(line string) (method, uri, proto string, ok bool) { - s1 := strings.Index(line, " ") - s2 := strings.Index(line[s1+1:], " ") - if s1 < 0 || s2 < 0 { - return - } - s2 += s1 + 1 - return line[:s1], line[s1+1 : s2], line[s2+1:], true -} - -// HTTP2 parses the frame header of the first frame to detect whether the -// connection is an HTTP2 connection. -func HTTP2() Matcher { - return hasHTTP2Preface -} - -// HTTP1HeaderField returns a matcher matching the header fields of the first -// request of an HTTP 1 connection. -func HTTP1HeaderField(name, value string) Matcher { - return func(r io.Reader) bool { - return matchHTTP1Field(r, name, func(gotValue string) bool { - return gotValue == value - }) - } -} - -// HTTP1HeaderFieldPrefix returns a matcher matching the header fields of the -// first request of an HTTP 1 connection. If the header with key name has a -// value prefixed with valuePrefix, this will match. -func HTTP1HeaderFieldPrefix(name, valuePrefix string) Matcher { - return func(r io.Reader) bool { - return matchHTTP1Field(r, name, func(gotValue string) bool { - return strings.HasPrefix(gotValue, valuePrefix) - }) - } -} - -// HTTP2HeaderField returns a matcher matching the header fields of the first -// headers frame. -func HTTP2HeaderField(name, value string) Matcher { - return func(r io.Reader) bool { - return matchHTTP2Field(ioutil.Discard, r, name, func(gotValue string) bool { - return gotValue == value - }) - } -} - -// HTTP2HeaderFieldPrefix returns a matcher matching the header fields of the -// first headers frame. If the header with key name has a value prefixed with -// valuePrefix, this will match. -func HTTP2HeaderFieldPrefix(name, valuePrefix string) Matcher { - return func(r io.Reader) bool { - return matchHTTP2Field(ioutil.Discard, r, name, func(gotValue string) bool { - return strings.HasPrefix(gotValue, valuePrefix) - }) - } -} - -// HTTP2MatchHeaderFieldSendSettings matches the header field and writes the -// settings to the server. Prefer HTTP2HeaderField over this one, if the client -// does not block on receiving a SETTING frame. -func HTTP2MatchHeaderFieldSendSettings(name, value string) MatchWriter { - return func(w io.Writer, r io.Reader) bool { - return matchHTTP2Field(w, r, name, func(gotValue string) bool { - return gotValue == value - }) - } -} - -// HTTP2MatchHeaderFieldPrefixSendSettings matches the header field prefix -// and writes the settings to the server. Prefer HTTP2HeaderFieldPrefix over -// this one, if the client does not block on receiving a SETTING frame. 
-func HTTP2MatchHeaderFieldPrefixSendSettings(name, valuePrefix string) MatchWriter { - return func(w io.Writer, r io.Reader) bool { - return matchHTTP2Field(w, r, name, func(gotValue string) bool { - return strings.HasPrefix(gotValue, valuePrefix) - }) - } -} - -func hasHTTP2Preface(r io.Reader) bool { - var b [len(http2.ClientPreface)]byte - last := 0 - - for { - n, err := r.Read(b[last:]) - if err != nil { - return false - } - - last += n - eq := string(b[:last]) == http2.ClientPreface[:last] - if last == len(http2.ClientPreface) { - return eq - } - if !eq { - return false - } - } -} - -func matchHTTP1Field(r io.Reader, name string, matches func(string) bool) (matched bool) { - req, err := http.ReadRequest(bufio.NewReader(r)) - if err != nil { - return false - } - - return matches(req.Header.Get(name)) -} - -func matchHTTP2Field(w io.Writer, r io.Reader, name string, matches func(string) bool) (matched bool) { - if !hasHTTP2Preface(r) { - return false - } - - done := false - framer := http2.NewFramer(w, r) - hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { - if hf.Name == name { - done = true - if matches(hf.Value) { - matched = true - } - } - }) - for { - f, err := framer.ReadFrame() - if err != nil { - return false - } - - switch f := f.(type) { - case *http2.SettingsFrame: - // Sender acknoweldged the SETTINGS frame. No need to write - // SETTINGS again. - if f.IsAck() { - break - } - if err := framer.WriteSettings(); err != nil { - return false - } - case *http2.ContinuationFrame: - if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { - return false - } - done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 - case *http2.HeadersFrame: - if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { - return false - } - done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 - } - - if done { - return matched - } - } -} diff --git a/vendor/github.com/soheilhy/cmux/patricia.go b/vendor/github.com/soheilhy/cmux/patricia.go deleted file mode 100644 index c3e3d85bdeaf..000000000000 --- a/vendor/github.com/soheilhy/cmux/patricia.go +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2016 The CMux Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -// implied. See the License for the specific language governing -// permissions and limitations under the License. - -package cmux - -import ( - "bytes" - "io" -) - -// patriciaTree is a simple patricia tree that handles []byte instead of string -// and cannot be changed after instantiation. -type patriciaTree struct { - root *ptNode - maxDepth int // max depth of the tree. -} - -func newPatriciaTree(bs ...[]byte) *patriciaTree { - max := 0 - for _, b := range bs { - if max < len(b) { - max = len(b) - } - } - return &patriciaTree{ - root: newNode(bs), - maxDepth: max + 1, - } -} - -func newPatriciaTreeString(strs ...string) *patriciaTree { - b := make([][]byte, len(strs)) - for i, s := range strs { - b[i] = []byte(s) - } - return newPatriciaTree(b...) 
-}
-
-func (t *patriciaTree) matchPrefix(r io.Reader) bool {
-	buf := make([]byte, t.maxDepth)
-	n, _ := io.ReadFull(r, buf)
-	return t.root.match(buf[:n], true)
-}
-
-func (t *patriciaTree) match(r io.Reader) bool {
-	buf := make([]byte, t.maxDepth)
-	n, _ := io.ReadFull(r, buf)
-	return t.root.match(buf[:n], false)
-}
-
-type ptNode struct {
-	prefix   []byte
-	next     map[byte]*ptNode
-	terminal bool
-}
-
-func newNode(strs [][]byte) *ptNode {
-	if len(strs) == 0 {
-		return &ptNode{
-			prefix:   []byte{},
-			terminal: true,
-		}
-	}
-
-	if len(strs) == 1 {
-		return &ptNode{
-			prefix:   strs[0],
-			terminal: true,
-		}
-	}
-
-	p, strs := splitPrefix(strs)
-	n := &ptNode{
-		prefix: p,
-	}
-
-	nexts := make(map[byte][][]byte)
-	for _, s := range strs {
-		if len(s) == 0 {
-			n.terminal = true
-			continue
-		}
-		nexts[s[0]] = append(nexts[s[0]], s[1:])
-	}
-
-	n.next = make(map[byte]*ptNode)
-	for first, rests := range nexts {
-		n.next[first] = newNode(rests)
-	}
-
-	return n
-}
-
-func splitPrefix(bss [][]byte) (prefix []byte, rest [][]byte) {
-	if len(bss) == 0 || len(bss[0]) == 0 {
-		return prefix, bss
-	}
-
-	if len(bss) == 1 {
-		return bss[0], [][]byte{{}}
-	}
-
-	for i := 0; ; i++ {
-		var cur byte
-		eq := true
-		for j, b := range bss {
-			if len(b) <= i {
-				eq = false
-				break
-			}
-
-			if j == 0 {
-				cur = b[i]
-				continue
-			}
-
-			if cur != b[i] {
-				eq = false
-				break
-			}
-		}
-
-		if !eq {
-			break
-		}
-
-		prefix = append(prefix, cur)
-	}
-
-	rest = make([][]byte, 0, len(bss))
-	for _, b := range bss {
-		rest = append(rest, b[len(prefix):])
-	}
-
-	return prefix, rest
-}
-
-func (n *ptNode) match(b []byte, prefix bool) bool {
-	l := len(n.prefix)
-	if l > 0 {
-		if l > len(b) {
-			l = len(b)
-		}
-		if !bytes.Equal(b[:l], n.prefix) {
-			return false
-		}
-	}
-
-	if n.terminal && (prefix || len(n.prefix) == len(b)) {
-		return true
-	}
-
-	if l >= len(b) {
-		return false
-	}
-
-	nextN, ok := n.next[b[l]]
-	if !ok {
-		return false
-	}
-
-	if l == len(b) {
-		b = b[l:l]
-	} else {
-		b = b[l+1:]
-	}
-	return nextN.match(b, prefix)
-}
diff --git a/vendor/google.golang.org/grpc/experimental/experimental.go b/vendor/google.golang.org/grpc/experimental/experimental.go
new file mode 100644
index 000000000000..de7f13a2210e
--- /dev/null
+++ b/vendor/google.golang.org/grpc/experimental/experimental.go
@@ -0,0 +1,65 @@
+/*
+ *
+ * Copyright 2023 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package experimental is a collection of experimental features that might
+// have some rough edges to them. Housing experimental features in this package
+// results in a user accessing these APIs as `experimental.Foo`, thereby making
+// it explicit that the feature is experimental and using them in production
+// code is at their own risk.
+//
+// All APIs in this package are experimental.
+package experimental
+
+import (
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/internal"
+)
+
+// WithRecvBufferPool returns a grpc.DialOption that configures the use of
+// bufferPool for parsing incoming messages on a grpc.ClientConn. Depending on
+// the application's workload, this could result in reduced memory allocation.
+//
+// If you are unsure about how to implement a memory pool but want to utilize
+// one, begin with grpc.NewSharedBufferPool.
+//
+// Note: The shared buffer pool feature will not be active if any of the
+// following options are used: WithStatsHandler, EnableTracing, or binary
+// logging. In such cases, the shared buffer pool will be ignored.
+//
+// Note: It is not recommended to use the shared buffer pool when compression is
+// enabled.
+func WithRecvBufferPool(bufferPool grpc.SharedBufferPool) grpc.DialOption {
+	return internal.WithRecvBufferPool.(func(grpc.SharedBufferPool) grpc.DialOption)(bufferPool)
+}
+
+// RecvBufferPool returns a grpc.ServerOption that configures the server to use
+// the provided shared buffer pool for parsing incoming messages. Depending on
+// the application's workload, this could result in reduced memory allocation.
+//
+// If you are unsure about how to implement a memory pool but want to utilize
+// one, begin with grpc.NewSharedBufferPool.
+//
+// Note: The shared buffer pool feature will not be active if any of the
+// following options are used: StatsHandler, EnableTracing, or binary logging.
+// In such cases, the shared buffer pool will be ignored.
+//
+// Note: It is not recommended to use the shared buffer pool when compression is
+// enabled.
+func RecvBufferPool(bufferPool grpc.SharedBufferPool) grpc.ServerOption {
+	return internal.RecvBufferPool.(func(grpc.SharedBufferPool) grpc.ServerOption)(bufferPool)
+}
diff --git a/vendor/modules.txt b/vendor/modules.txt
index 904dbc58f4d0..d414bbd3d538 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -927,7 +927,7 @@ github.com/gorilla/websocket
 # github.com/grafana/cloudflare-go v0.0.0-20230110200409-c627cf6792f2
 ## explicit; go 1.17
 github.com/grafana/cloudflare-go
-# github.com/grafana/dskit v0.0.0-20240104111617-ea101a3b86eb
+# github.com/grafana/dskit v0.0.0-20240528015923-27d7d41066d3
 ## explicit; go 1.20
 github.com/grafana/dskit/aws
 github.com/grafana/dskit/backoff
@@ -976,7 +976,7 @@ github.com/grafana/dskit/user
 # github.com/grafana/go-gelf/v2 v2.0.1
 ## explicit; go 1.17
 github.com/grafana/go-gelf/v2/gelf
-# github.com/grafana/gomemcache v0.0.0-20231204155601-7de47a8c3cb0
+# github.com/grafana/gomemcache v0.0.0-20240229205252-cd6a66d6fb56
 ## explicit; go 1.18
 github.com/grafana/gomemcache/memcache
 # github.com/grafana/jsonparser v0.0.0-20240209175146-098958973a2d
@@ -1297,6 +1297,9 @@ github.com/pierrec/lz4/v4/internal/lz4block
 github.com/pierrec/lz4/v4/internal/lz4errors
 github.com/pierrec/lz4/v4/internal/lz4stream
 github.com/pierrec/lz4/v4/internal/xxh32
+# github.com/pires/go-proxyproto v0.7.0
+## explicit; go 1.18
+github.com/pires/go-proxyproto
 # github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
 ## explicit; go 1.14
 github.com/pkg/browser
@@ -1468,9 +1471,6 @@ github.com/shurcooL/vfsgen
 # github.com/sirupsen/logrus v1.9.3
 ## explicit; go 1.13
 github.com/sirupsen/logrus
-# github.com/soheilhy/cmux v0.1.5
-## explicit; go 1.11
-github.com/soheilhy/cmux
 # github.com/sony/gobreaker v0.5.0
 ## explicit; go 1.12
 github.com/sony/gobreaker
@@ -1912,6 +1912,7 @@ google.golang.org/grpc/credentials/tls/certprovider/pemfile
 google.golang.org/grpc/encoding
 google.golang.org/grpc/encoding/gzip
 google.golang.org/grpc/encoding/proto
+google.golang.org/grpc/experimental
 google.golang.org/grpc/grpclog
 google.golang.org/grpc/health/grpc_health_v1
 google.golang.org/grpc/internal
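
Reviewer note (not part of the applied diff): the vendored google.golang.org/grpc/experimental package added above exposes the two receive-buffer-pool entry points this change pulls in. Below is a minimal sketch of how they are typically wired up on both ends of a connection; the loopback listener address and insecure credentials are illustrative placeholders, and grpc.NewSharedBufferPool is the stock pool implementation the package's doc comments suggest starting from.

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/experimental"
)

func main() {
	// Server side: RecvBufferPool reuses receive buffers when parsing
	// incoming messages instead of allocating a fresh buffer per message.
	// Per the doc comments above, the pool is silently ignored when a
	// StatsHandler, tracing, or binary logging is also configured.
	srv := grpc.NewServer(
		experimental.RecvBufferPool(grpc.NewSharedBufferPool()),
	)

	lis, err := net.Listen("tcp", "127.0.0.1:0") // placeholder address
	if err != nil {
		log.Fatal(err)
	}
	go srv.Serve(lis)
	defer srv.Stop()

	// Client side: WithRecvBufferPool is the dial-option counterpart.
	conn, err := grpc.Dial(
		lis.Addr().String(),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		experimental.WithRecvBufferPool(grpc.NewSharedBufferPool()),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
}

The caveat in the vendored doc comments explains why the feature sits behind an explicit opt-in: stats handlers need to retain message payloads, so buffer pooling and per-message stats tracking cannot be active at the same time.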