From b26a6a60bbf0969c7ca5037d9ab6e485d336ced4 Mon Sep 17 00:00:00 2001 From: albertteoh Date: Fri, 4 Jun 2021 17:43:52 +1000 Subject: [PATCH 1/2] Add TLS support Signed-off-by: albertteoh --- pkg/prometheus/config/config.go | 11 ++-- plugin/metrics/prometheus/factory.go | 2 +- plugin/metrics/prometheus/factory_test.go | 10 ++-- .../metrics/prometheus/metricsstore/reader.go | 42 ++++++++++----- .../prometheus/metricsstore/reader_test.go | 54 +++++++++++++++++-- plugin/metrics/prometheus/options.go | 22 ++++++-- 6 files changed, 109 insertions(+), 32 deletions(-) diff --git a/pkg/prometheus/config/config.go b/pkg/prometheus/config/config.go index 494aa7bc2ff..3ac662adbc3 100644 --- a/pkg/prometheus/config/config.go +++ b/pkg/prometheus/config/config.go @@ -14,10 +14,15 @@ package config -import "time" +import ( + "time" + + "github.com/jaegertracing/jaeger/pkg/config/tlscfg" +) // Configuration describes the options to customize the storage behavior. type Configuration struct { - HostPort string `validate:"nonzero" mapstructure:"server"` - ConnectTimeout time.Duration `validate:"nonzero" mapstructure:"timeout"` + ServerURL string `validate:"nonzero" mapstructure:"server"` + ConnectTimeout time.Duration `validate:"nonzero" mapstructure:"timeout"` + TLS tlscfg.Options `mapstructure:"tls"` } diff --git a/plugin/metrics/prometheus/factory.go b/plugin/metrics/prometheus/factory.go index 8435be82bda..31802418a0d 100644 --- a/plugin/metrics/prometheus/factory.go +++ b/plugin/metrics/prometheus/factory.go @@ -55,5 +55,5 @@ func (f *Factory) Initialize(logger *zap.Logger) error { // CreateMetricsReader implements storage.MetricsFactory. func (f *Factory) CreateMetricsReader() (metricsstore.Reader, error) { - return prometheusstore.NewMetricsReader(f.logger, f.options.Primary.HostPort, f.options.Primary.ConnectTimeout) + return prometheusstore.NewMetricsReader(f.logger, f.options.Primary.Configuration) } diff --git a/plugin/metrics/prometheus/factory_test.go b/plugin/metrics/prometheus/factory_test.go index bfa801b0220..e3b028a020b 100644 --- a/plugin/metrics/prometheus/factory_test.go +++ b/plugin/metrics/prometheus/factory_test.go @@ -40,7 +40,7 @@ func TestPrometheusFactory(t *testing.T) { assert.NotNil(t, listener) defer listener.Close() - f.options.Primary.HostPort = listener.Addr().String() + f.options.Primary.ServerURL = "http://" + listener.Addr().String() reader, err := f.CreateMetricsReader() assert.NoError(t, err) @@ -49,20 +49,20 @@ func TestPrometheusFactory(t *testing.T) { func TestWithDefaultConfiguration(t *testing.T) { f := NewFactory() - assert.Equal(t, f.options.Primary.HostPort, defaultServerHostPort) - assert.Equal(t, f.options.Primary.ConnectTimeout, defaultConnectTimeout) + assert.Equal(t, f.options.Primary.ServerURL, "http://localhost:9090") + assert.Equal(t, f.options.Primary.ConnectTimeout, 30*time.Second) } func TestWithConfiguration(t *testing.T) { f := NewFactory() v, command := config.Viperize(f.AddFlags) err := command.ParseFlags([]string{ - "--prometheus.host-port=localhost:1234", + "--prometheus.server-url=http://localhost:1234", "--prometheus.connect-timeout=5s", }) require.NoError(t, err) f.InitFromViper(v) - assert.Equal(t, f.options.Primary.HostPort, "localhost:1234") + assert.Equal(t, f.options.Primary.ServerURL, "http://localhost:1234") assert.Equal(t, f.options.Primary.ConnectTimeout, 5*time.Second) } diff --git a/plugin/metrics/prometheus/metricsstore/reader.go b/plugin/metrics/prometheus/metricsstore/reader.go index 99f5cde4c58..3e6be7a54d3 100644 --- a/plugin/metrics/prometheus/metricsstore/reader.go +++ b/plugin/metrics/prometheus/metricsstore/reader.go @@ -30,6 +30,7 @@ import ( promapi "github.com/prometheus/client_golang/api/prometheus/v1" "go.uber.org/zap" + "github.com/jaegertracing/jaeger/pkg/prometheus/config" "github.com/jaegertracing/jaeger/plugin/metrics/prometheus/metricsstore/dbmodel" "github.com/jaegertracing/jaeger/proto-gen/api_v2/metrics" "github.com/jaegertracing/jaeger/storage/metricsstore" @@ -72,20 +73,13 @@ type ( ) // NewMetricsReader returns a new MetricsReader. -func NewMetricsReader(logger *zap.Logger, hostPort string, connectTimeout time.Duration) (*MetricsReader, error) { - // KeepAlive and TLSHandshake timeouts are kept to existing Prometheus client's - // DefaultRoundTripper to simplify user configuration and may be made configurable when required. - roundTripper := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: connectTimeout, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 10 * time.Second, +func NewMetricsReader(logger *zap.Logger, cfg config.Configuration) (*MetricsReader, error) { + roundTripper, err := getHTTPRoundTripper(&cfg, logger) + if err != nil { + return nil, err } - client, err := api.NewClient(api.Config{ - Address: "http://" + hostPort, + Address: cfg.ServerURL, RoundTripper: roundTripper, }) if err != nil { @@ -95,7 +89,7 @@ func NewMetricsReader(logger *zap.Logger, hostPort string, connectTimeout time.D client: promapi.NewAPI(client), logger: logger, } - logger.Info("Prometheus reader initialized", zap.String("addr", hostPort)) + logger.Info("Prometheus reader initialized", zap.String("addr", cfg.ServerURL)) return mr, nil } @@ -247,3 +241,25 @@ func logErrorToSpan(span opentracing.Span, err error) { ottag.Error.Set(span, true) span.LogFields(otlog.Error(err)) } + +func getHTTPRoundTripper(c *config.Configuration, logger *zap.Logger) (http.RoundTripper, error) { + if !c.TLS.Enabled { + return nil, nil + } + ctlsConfig, err := c.TLS.Config(logger) + if err != nil { + return nil, err + } + + // KeepAlive and TLSHandshake timeouts are kept to existing Prometheus client's + // DefaultRoundTripper to simplify user configuration and may be made configurable when required. + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: c.ConnectTimeout, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: ctlsConfig, + }, nil +} diff --git a/plugin/metrics/prometheus/metricsstore/reader_test.go b/plugin/metrics/prometheus/metricsstore/reader_test.go index 16c38132c11..5cb61f72b38 100644 --- a/plugin/metrics/prometheus/metricsstore/reader_test.go +++ b/plugin/metrics/prometheus/metricsstore/reader_test.go @@ -31,6 +31,8 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap" + "github.com/jaegertracing/jaeger/pkg/config/tlscfg" + "github.com/jaegertracing/jaeger/pkg/prometheus/config" "github.com/jaegertracing/jaeger/proto-gen/api_v2/metrics" "github.com/jaegertracing/jaeger/storage/metricsstore" ) @@ -52,14 +54,20 @@ const defaultTimeout = 30 * time.Second func TestNewMetricsReaderValidAddress(t *testing.T) { logger := zap.NewNop() - reader, err := NewMetricsReader(logger, "localhost:1234", defaultTimeout) + reader, err := NewMetricsReader(logger, config.Configuration{ + ServerURL: "http://localhost:1234", + ConnectTimeout: defaultTimeout, + }) require.NoError(t, err) assert.NotNil(t, reader) } func TestNewMetricsReaderInvalidAddress(t *testing.T) { logger := zap.NewNop() - reader, err := NewMetricsReader(logger, "\n", defaultTimeout) + reader, err := NewMetricsReader(logger, config.Configuration{ + ServerURL: "\n", + ConnectTimeout: defaultTimeout, + }) require.Error(t, err) assert.Contains(t, err.Error(), "failed to initialize prometheus client") assert.Nil(t, reader) @@ -72,7 +80,10 @@ func TestGetMinStepDuration(t *testing.T) { require.NoError(t, err) assert.NotNil(t, listener) - reader, err := NewMetricsReader(logger, listener.Addr().String(), defaultTimeout) + reader, err := NewMetricsReader(logger, config.Configuration{ + ServerURL: "http://" + listener.Addr().String(), + ConnectTimeout: defaultTimeout, + }) require.NoError(t, err) minStep, err := reader.GetMinStepDuration(context.Background(), ¶ms) @@ -102,7 +113,10 @@ func TestMetricsServerError(t *testing.T) { logger := zap.NewNop() address := mockPrometheus.Listener.Addr().String() - reader, err := NewMetricsReader(logger, address, defaultTimeout) + reader, err := NewMetricsReader(logger, config.Configuration{ + ServerURL: "http://" + address, + ConnectTimeout: defaultTimeout, + }) require.NoError(t, err) m, err := reader.GetCallRates(context.Background(), ¶ms) @@ -299,6 +313,33 @@ func TestWarningResponse(t *testing.T) { assert.NotNil(t, m) } +func TestTLSEnabledMetricsReader(t *testing.T) { + logger := zap.NewNop() + reader, err := NewMetricsReader(logger, config.Configuration{ + ServerURL: "https://localhost:1234", + ConnectTimeout: defaultTimeout, + TLS: tlscfg.Options{ + Enabled: true, + }, + }) + require.NoError(t, err) + assert.NotNil(t, reader) +} + +func TestInvalidCertFile(t *testing.T) { + logger := zap.NewNop() + reader, err := NewMetricsReader(logger, config.Configuration{ + ServerURL: "https://localhost:1234", + ConnectTimeout: defaultTimeout, + TLS: tlscfg.Options{ + Enabled: true, + CAPath: "foo", + }, + }) + require.Error(t, err) + assert.Nil(t, reader) +} + func startMockPrometheusServer(t *testing.T, wantPromQlQuery string, wantWarnings []string) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if len(wantWarnings) > 0 { @@ -356,7 +397,10 @@ func prepareMetricsReaderAndServer(t *testing.T, wantPromQlQuery string, wantWar logger := zap.NewNop() address := mockPrometheus.Listener.Addr().String() - reader, err := NewMetricsReader(logger, address, defaultTimeout) + reader, err := NewMetricsReader(logger, config.Configuration{ + ServerURL: "http://" + address, + ConnectTimeout: defaultTimeout, + }) require.NoError(t, err) return reader, mockPrometheus } diff --git a/plugin/metrics/prometheus/options.go b/plugin/metrics/prometheus/options.go index 8d346ef6c36..76afd5252c5 100644 --- a/plugin/metrics/prometheus/options.go +++ b/plugin/metrics/prometheus/options.go @@ -21,14 +21,15 @@ import ( "github.com/spf13/viper" + "github.com/jaegertracing/jaeger/pkg/config/tlscfg" "github.com/jaegertracing/jaeger/pkg/prometheus/config" ) const ( - suffixHostPort = ".host-port" + suffixServerURL = ".server-url" suffixConnectTimeout = ".connect-timeout" - defaultServerHostPort = "localhost:9090" + defaultServerURL = "http://localhost:9090" defaultConnectTimeout = 30 * time.Second ) @@ -45,7 +46,7 @@ type Options struct { // NewOptions creates a new Options struct. func NewOptions(primaryNamespace string) *Options { defaultConfig := config.Configuration{ - HostPort: defaultServerHostPort, + ServerURL: defaultServerURL, ConnectTimeout: defaultConnectTimeout, } @@ -60,15 +61,26 @@ func NewOptions(primaryNamespace string) *Options { // AddFlags from this storage to the CLI. func (opt *Options) AddFlags(flagSet *flag.FlagSet) { nsConfig := &opt.Primary - flagSet.String(nsConfig.namespace+suffixHostPort, defaultServerHostPort, "The host:port of the Prometheus query service.") + flagSet.String(nsConfig.namespace+suffixServerURL, defaultServerURL, "The Prometheus server's URL, must include the protocol scheme e.g. http://localhost:9090") flagSet.Duration(nsConfig.namespace+suffixConnectTimeout, defaultConnectTimeout, "The period to wait for a connection to Prometheus when executing queries.") + + nsConfig.getTLSFlagsConfig().AddFlags(flagSet) } // InitFromViper initializes the options struct with values from Viper. func (opt *Options) InitFromViper(v *viper.Viper) { cfg := &opt.Primary - cfg.HostPort = stripWhiteSpace(v.GetString(cfg.namespace + suffixHostPort)) + cfg.ServerURL = stripWhiteSpace(v.GetString(cfg.namespace + suffixServerURL)) cfg.ConnectTimeout = v.GetDuration(cfg.namespace + suffixConnectTimeout) + cfg.TLS = cfg.getTLSFlagsConfig().InitFromViper(v) +} + +func (config *namespaceConfig) getTLSFlagsConfig() tlscfg.ClientFlagsConfig { + return tlscfg.ClientFlagsConfig{ + Prefix: config.namespace, + ShowEnabled: true, + ShowServerName: true, + } } // stripWhiteSpace removes all whitespace characters from a string. From e9d671294318c55fabd5f201754db3e6de678ef4 Mon Sep 17 00:00:00 2001 From: albertteoh Date: Sat, 5 Jun 2021 06:58:26 +1000 Subject: [PATCH 2/2] Address review comments Signed-off-by: albertteoh --- pkg/prometheus/config/config.go | 6 +-- .../metrics/prometheus/metricsstore/reader.go | 14 +++---- .../prometheus/metricsstore/reader_test.go | 37 +++++++++++++------ 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/pkg/prometheus/config/config.go b/pkg/prometheus/config/config.go index 3ac662adbc3..6560f51fd9a 100644 --- a/pkg/prometheus/config/config.go +++ b/pkg/prometheus/config/config.go @@ -22,7 +22,7 @@ import ( // Configuration describes the options to customize the storage behavior. type Configuration struct { - ServerURL string `validate:"nonzero" mapstructure:"server"` - ConnectTimeout time.Duration `validate:"nonzero" mapstructure:"timeout"` - TLS tlscfg.Options `mapstructure:"tls"` + ServerURL string + ConnectTimeout time.Duration + TLS tlscfg.Options } diff --git a/plugin/metrics/prometheus/metricsstore/reader.go b/plugin/metrics/prometheus/metricsstore/reader.go index 3e6be7a54d3..257d1d1616a 100644 --- a/plugin/metrics/prometheus/metricsstore/reader.go +++ b/plugin/metrics/prometheus/metricsstore/reader.go @@ -16,6 +16,7 @@ package metricsstore import ( "context" + "crypto/tls" "fmt" "net" "net/http" @@ -242,13 +243,12 @@ func logErrorToSpan(span opentracing.Span, err error) { span.LogFields(otlog.Error(err)) } -func getHTTPRoundTripper(c *config.Configuration, logger *zap.Logger) (http.RoundTripper, error) { - if !c.TLS.Enabled { - return nil, nil - } - ctlsConfig, err := c.TLS.Config(logger) - if err != nil { - return nil, err +func getHTTPRoundTripper(c *config.Configuration, logger *zap.Logger) (rt http.RoundTripper, err error) { + var ctlsConfig *tls.Config + if c.TLS.Enabled { + if ctlsConfig, err = c.TLS.Config(logger); err != nil { + return nil, err + } } // KeepAlive and TLSHandshake timeouts are kept to existing Prometheus client's diff --git a/plugin/metrics/prometheus/metricsstore/reader_test.go b/plugin/metrics/prometheus/metricsstore/reader_test.go index 5cb61f72b38..560a5786917 100644 --- a/plugin/metrics/prometheus/metricsstore/reader_test.go +++ b/plugin/metrics/prometheus/metricsstore/reader_test.go @@ -313,17 +313,32 @@ func TestWarningResponse(t *testing.T) { assert.NotNil(t, m) } -func TestTLSEnabledMetricsReader(t *testing.T) { - logger := zap.NewNop() - reader, err := NewMetricsReader(logger, config.Configuration{ - ServerURL: "https://localhost:1234", - ConnectTimeout: defaultTimeout, - TLS: tlscfg.Options{ - Enabled: true, - }, - }) - require.NoError(t, err) - assert.NotNil(t, reader) +func TestGetRoundTripper(t *testing.T) { + for _, tc := range []struct { + name string + tlsEnabled bool + }{ + {"tls tlsEnabled", true}, + {"tls disabled", false}, + } { + t.Run(tc.name, func(t *testing.T) { + logger := zap.NewNop() + rt, err := getHTTPRoundTripper(&config.Configuration{ + ServerURL: "https://localhost:1234", + ConnectTimeout: 9 * time.Millisecond, + TLS: tlscfg.Options{ + Enabled: tc.tlsEnabled, + }, + }, logger) + require.NoError(t, err) + assert.IsType(t, &http.Transport{}, rt) + if tc.tlsEnabled { + assert.NotNil(t, rt.(*http.Transport).TLSClientConfig) + } else { + assert.Nil(t, rt.(*http.Transport).TLSClientConfig) + } + }) + } } func TestInvalidCertFile(t *testing.T) {