From 690ad5427677734fe1b6eb709af9672f5b1a8239 Mon Sep 17 00:00:00 2001 From: Alexander Mays Date: Thu, 12 May 2022 18:28:26 -0500 Subject: [PATCH] Add RDS IAM auth plugin for SQL drivers (#2830) Signed-off-by: Alexander Mays Added: A new environment variable to the CLI: SQL_AUTH_PLUGIN=rds-iam-auth A new flag arg to the CLI: --sql-auth-plugin rds-iam-auth 2 new docker template variables: authPlugin: {{ default .Env.SQL_AUTH_PLUGIN "" }} and authPlugin: {{ default .Env.SQL_VIS_AUTH_PLUGIN "" }} A new SQL configuration attribute authPlugin --- Makefile | 2 +- common/config/config.go | 9 + .../sql/sqlplugin/auth/auth_plugin.go | 57 +++++++ .../sql/sqlplugin/auth/auth_plugin_mock.go | 75 +++++++++ .../sql/sqlplugin/auth/rds_auth_plugin.go | 159 ++++++++++++++++++ .../sqlplugin/auth/rds_auth_plugin_test.go | 100 +++++++++++ .../sql/sqlplugin/mysql/session/session.go | 23 +++ .../sqlplugin/mysql/session/session_test.go | 51 ++++++ .../sqlplugin/postgresql/session/session.go | 23 +++ docker/config_template.yaml | 4 + go.mod | 15 ++ go.sum | 24 +++ tools/common/schema/types.go | 4 + tools/sql/handler.go | 11 ++ tools/sql/main.go | 6 + 15 files changed, 562 insertions(+), 1 deletion(-) create mode 100644 common/persistence/sql/sqlplugin/auth/auth_plugin.go create mode 100644 common/persistence/sql/sqlplugin/auth/auth_plugin_mock.go create mode 100644 common/persistence/sql/sqlplugin/auth/rds_auth_plugin.go create mode 100644 common/persistence/sql/sqlplugin/auth/rds_auth_plugin_test.go diff --git a/Makefile b/Makefile index 8f1eda9044b..642fe2b86a1 100644 --- a/Makefile +++ b/Makefile @@ -423,7 +423,7 @@ start-cdc-other: temporal-server ./temporal-server --zone other start ##### Mocks ##### -AWS_SDK_VERSION := $(lastword $(shell grep "github.com/aws/aws-sdk-go" go.mod)) +AWS_SDK_VERSION := $(lastword $(shell grep "github.com/aws/aws-sdk-go v1" go.mod)) external-mocks: @printf $(COLOR) "Generate external libraries mocks..." @mockgen -copyright_file ./LICENSE -package mocks -source $(GOPATH)/pkg/mod/github.com/aws/aws-sdk-go@$(AWS_SDK_VERSION)/service/s3/s3iface/interface.go | grep -v -e "^// Source: .*" > common/archiver/s3store/mocks/S3API.go diff --git a/common/config/config.go b/common/config/config.go index 6e9aeb3971c..adb38371f28 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -320,6 +320,15 @@ type ( TaskScanPartitions int `yaml:"taskScanPartitions"` // TLS is the configuration for TLS connections TLS *auth.TLS `yaml:"tls"` + // AuthPlugin is the configuration for a SQL authentication plugin + // - currently drivers 'mysql' and 'postgres' support 'rds-iam-auth' + AuthPlugin *SQLAuthPlugin `yaml:"authPlugin"` + } + + // SQLAuthPlugin determines which sql auth plugin is invoked for each new SQL session + SQLAuthPlugin struct { + Plugin string `yaml:"plugin"` + Timeout time.Duration `yaml:"timeout"` } // CustomDatastoreConfig is the configuration for connecting to a custom datastore that is not supported by temporal core diff --git a/common/persistence/sql/sqlplugin/auth/auth_plugin.go b/common/persistence/sql/sqlplugin/auth/auth_plugin.go new file mode 100644 index 00000000000..69725d639a3 --- /dev/null +++ b/common/persistence/sql/sqlplugin/auth/auth_plugin.go @@ -0,0 +1,57 @@ +// The MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +//go:generate mockgen -copyright_file ../../../../../LICENSE -package $GOPACKAGE -source $GOFILE -destination auth_plugin_mock.go +package auth + +import ( + "context" + "errors" + + "go.temporal.io/server/common/config" +) + +var ( + ErrInvalidAuthPluginName = errors.New("auth_plugin: invalid auth plugin requested") + plugins = map[string]AuthPlugin{} +) + +type ( + // AuthPlugin interface for mutating SQL connection parameters + AuthPlugin interface { + // GetConfig returns a mutated SQL config + GetConfig(context.Context, *config.SQL) (*config.SQL, error) + } +) + +// RegisterPlugin adds an auth plugin to the plugin registry +// it is only safe to use from a package init function +func RegisterPlugin(name string, plugin AuthPlugin) { + plugins[name] = plugin +} + +func LookupPlugin(name string) (AuthPlugin, error) { + plugin, ok := plugins[name] + if !ok { + return nil, ErrInvalidAuthPluginName + } + + return plugin, nil +} diff --git a/common/persistence/sql/sqlplugin/auth/auth_plugin_mock.go b/common/persistence/sql/sqlplugin/auth/auth_plugin_mock.go new file mode 100644 index 00000000000..732583a7870 --- /dev/null +++ b/common/persistence/sql/sqlplugin/auth/auth_plugin_mock.go @@ -0,0 +1,75 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: auth_plugin.go + +// Package auth is a generated GoMock package. +package auth + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + config "go.temporal.io/server/common/config" +) + +// MockAuthPlugin is a mock of AuthPlugin interface. +type MockAuthPlugin struct { + ctrl *gomock.Controller + recorder *MockAuthPluginMockRecorder +} + +// MockAuthPluginMockRecorder is the mock recorder for MockAuthPlugin. +type MockAuthPluginMockRecorder struct { + mock *MockAuthPlugin +} + +// NewMockAuthPlugin creates a new mock instance. +func NewMockAuthPlugin(ctrl *gomock.Controller) *MockAuthPlugin { + mock := &MockAuthPlugin{ctrl: ctrl} + mock.recorder = &MockAuthPluginMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAuthPlugin) EXPECT() *MockAuthPluginMockRecorder { + return m.recorder +} + +// GetConfig mocks base method. +func (m *MockAuthPlugin) GetConfig(arg0 context.Context, arg1 *config.SQL) (*config.SQL, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetConfig", arg0, arg1) + ret0, _ := ret[0].(*config.SQL) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetConfig indicates an expected call of GetConfig. +func (mr *MockAuthPluginMockRecorder) GetConfig(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockAuthPlugin)(nil).GetConfig), arg0, arg1) +} diff --git a/common/persistence/sql/sqlplugin/auth/rds_auth_plugin.go b/common/persistence/sql/sqlplugin/auth/rds_auth_plugin.go new file mode 100644 index 00000000000..0cd75f9d34b --- /dev/null +++ b/common/persistence/sql/sqlplugin/auth/rds_auth_plugin.go @@ -0,0 +1,159 @@ +// The MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package auth + +import ( + "context" + "encoding/base64" + "errors" + "io/ioutil" + "net/http" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + AWSConfig "github.com/aws/aws-sdk-go-v2/config" + AWSAuth "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + + "go.temporal.io/server/common/auth" + "go.temporal.io/server/common/config" +) + +const defaultTimeout = time.Second * 10 +const rdsCaUrl = "https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem" + +var rdsAuthFn = AWSAuth.BuildAuthToken + +func init() { + RegisterPlugin("rds-iam-auth", NewRDSAuthPlugin(nil)) +} + +func fetchRdsCA(ctx context.Context) (string, error) { + ctx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", rdsCaUrl, nil) + if err != nil { + return "", err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + defer resp.Body.Close() + pem, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(pem), nil +} + +type RDSAuthPlugin struct { + awsConfig *aws.Config + rdsPemBundle string + initRdsPemBundle sync.Once +} + +func NewRDSAuthPlugin(awsConfig *aws.Config) AuthPlugin { + return &RDSAuthPlugin{ + awsConfig: awsConfig, + } +} + +func (plugin *RDSAuthPlugin) getToken(ctx context.Context, addr string, region string, user string, credentials aws.CredentialsProvider) (string, error) { + reqCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + + return rdsAuthFn(reqCtx, addr, region, user, credentials) +} + +func (plugin *RDSAuthPlugin) resolveAwsConfig(ctx context.Context) (*aws.Config, error) { + if plugin.awsConfig != nil { + return plugin.awsConfig, nil + } + + reqCtx, cancel := context.WithTimeout(ctx, defaultTimeout) + defer cancel() + + cfg, err := AWSConfig.LoadDefaultConfig(reqCtx) + if err != nil { + return nil, err + } + + return &cfg, nil +} + +func (plugin *RDSAuthPlugin) GetConfig(ctx context.Context, cfg *config.SQL) (*config.SQL, error) { + awsCfg, err := plugin.resolveAwsConfig(ctx) + if err != nil { + return nil, err + } + + token, err := plugin.getToken(ctx, cfg.ConnectAddr, awsCfg.Region, cfg.User, awsCfg.Credentials) + if err != nil { + return nil, err + } + + cfg.Password = token + cfg.ConnectProtocol = "tcp" + + if cfg.ConnectAttributes == nil { + cfg.ConnectAttributes = map[string]string{} + } + + // mysql requires this plugin to use the token as a password + if cfg.PluginName == "mysql" { + cfg.ConnectAttributes["allowCleartextPasswords"] = "true" + } + + // if TLS is not configured, we default to the RDS CA + // this is required for mysql to send cleartext passwords + if cfg.TLS == nil { + var fetchErr error + plugin.initRdsPemBundle.Do(func() { + ca, err := fetchRdsCA(ctx) + if err != nil { + fetchErr = err + return + } + + plugin.rdsPemBundle = ca + }) + + if fetchErr != nil { + return nil, fetchErr + } + + if plugin.rdsPemBundle == "" { + return nil, errors.New("rds_auth_plugin: unable to retrieve rds ca certificates") + } + + cfg.TLS = &auth.TLS{ + Enabled: true, + CaData: plugin.rdsPemBundle, + } + } + + return cfg, nil +} diff --git a/common/persistence/sql/sqlplugin/auth/rds_auth_plugin_test.go b/common/persistence/sql/sqlplugin/auth/rds_auth_plugin_test.go new file mode 100644 index 00000000000..d899629a52d --- /dev/null +++ b/common/persistence/sql/sqlplugin/auth/rds_auth_plugin_test.go @@ -0,0 +1,100 @@ +// The MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package auth + +import ( + "context" + "sync" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/suite" + TLS "go.temporal.io/server/common/auth" + "go.temporal.io/server/common/config" +) + +type ( + rdsAuthPluginTestSuite struct { + suite.Suite + controller *gomock.Controller + } +) + +func TestSessionTestSuite(t *testing.T) { + s := new(rdsAuthPluginTestSuite) + suite.Run(t, s) +} + +func (s *rdsAuthPluginTestSuite) SetupSuite() { + +} + +func (s *rdsAuthPluginTestSuite) TearDownSuite() { + +} + +func (s *rdsAuthPluginTestSuite) SetupTest() { + s.controller = gomock.NewController(s.T()) +} + +func (s *rdsAuthPluginTestSuite) TearDownTest() { + s.controller.Finish() +} + +func (s *rdsAuthPluginTestSuite) TestRdsAuthPlugin() { + originalFn := rdsAuthFn + defer func() { + rdsAuthFn = originalFn + }() + + rdsAuthFn = func(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *auth.BuildAuthTokenOptions)) (string, error) { + return "token", nil + } + syncRdsPemBundle := sync.Once{} + syncRdsPemBundle.Do(func() {}) + + plugin := &RDSAuthPlugin{ + awsConfig: aws.NewConfig(), + rdsPemBundle: "test", + initRdsPemBundle: syncRdsPemBundle, + } + + cfg, err := plugin.GetConfig(context.TODO(), &config.SQL{ + PluginName: "mysql", + ConnectAttributes: map[string]string{}, + }) + + s.Equal(nil, err) + s.Equal(&config.SQL{ + PluginName: "mysql", + Password: "token", + ConnectProtocol: "tcp", + ConnectAttributes: map[string]string{ + "allowCleartextPasswords": "true", + }, + TLS: &TLS.TLS{ + Enabled: true, + CaData: "test", + }, + }, cfg) +} diff --git a/common/persistence/sql/sqlplugin/mysql/session/session.go b/common/persistence/sql/sqlplugin/mysql/session/session.go index b9490d5043c..1ee80d233df 100644 --- a/common/persistence/sql/sqlplugin/mysql/session/session.go +++ b/common/persistence/sql/sqlplugin/mysql/session/session.go @@ -25,11 +25,13 @@ package session import ( + "context" "crypto/tls" "crypto/x509" "fmt" "os" "strings" + "time" "github.com/go-sql-driver/mysql" "github.com/iancoleman/strcase" @@ -37,6 +39,7 @@ import ( "go.temporal.io/server/common/auth" "go.temporal.io/server/common/config" + SQLAuth "go.temporal.io/server/common/persistence/sql/sqlplugin/auth" "go.temporal.io/server/common/resolver" ) @@ -63,6 +66,26 @@ func NewSession( cfg *config.SQL, resolver resolver.ServiceResolver, ) (*Session, error) { + if cfg.AuthPlugin != nil && cfg.AuthPlugin.Plugin != "" { + authPlugin, err := SQLAuth.LookupPlugin(cfg.AuthPlugin.Plugin) + if err != nil { + return nil, err + } + + timeout := cfg.AuthPlugin.Timeout + if timeout == 0 { + timeout = time.Duration(time.Second * 10) + } + + ctx, cancel := context.WithTimeout(context.TODO(), timeout) + defer cancel() + + cfg, err = authPlugin.GetConfig(ctx, cfg) + if err != nil { + return nil, err + } + } + db, err := createConnection(cfg, resolver) if err != nil { return nil, err diff --git a/common/persistence/sql/sqlplugin/mysql/session/session_test.go b/common/persistence/sql/sqlplugin/mysql/session/session_test.go index 6343daecba8..66534d60260 100644 --- a/common/persistence/sql/sqlplugin/mysql/session/session_test.go +++ b/common/persistence/sql/sqlplugin/mysql/session/session_test.go @@ -25,6 +25,7 @@ package session import ( + "errors" "net/url" "strings" "testing" @@ -33,6 +34,7 @@ import ( "github.com/stretchr/testify/suite" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/persistence/sql/sqlplugin/auth" "go.temporal.io/server/common/resolver" ) @@ -152,6 +154,55 @@ func (s *sessionTestSuite) TestBuildDSN() { } } +func (s *sessionTestSuite) TestAuthPlugins() { + testCases := []struct { + in config.SQL + expectedError error + }{ + { + in: config.SQL{ + User: "test", + Password: "pass", + ConnectProtocol: "tcp", + ConnectAddr: "192.168.0.1:3306", + DatabaseName: "db1", + AuthPlugin: &config.SQLAuthPlugin{ + Plugin: "unsupported", + }, + }, + expectedError: auth.ErrInvalidAuthPluginName, + }, + { + in: config.SQL{ + User: "test", + Password: "pass", + ConnectProtocol: "tcp", + ConnectAddr: "192.168.0.1:3306", + DatabaseName: "db1", + AuthPlugin: &config.SQLAuthPlugin{ + Plugin: "rds-iam-auth", + }, + }, + expectedError: errors.New("NOOP"), + }, + } + + rdsPlugin, _ := auth.LookupPlugin("rds-iam-auth") + rdsMock := auth.NewMockAuthPlugin(s.controller) + auth.RegisterPlugin("rds-iam-auth", rdsMock) + defer func() { + auth.RegisterPlugin("rds-iam-auth", rdsPlugin) + }() + + // return a noop error to avoid having to mock the SQL connection + rdsMock.EXPECT().GetConfig(gomock.Any(), gomock.Any()).Return(nil, errors.New("NOOP")).AnyTimes() + + for _, tc := range testCases { + _, err := NewSession(&tc.in, nil) + s.Equal(tc.expectedError, err) + } +} + func buildExpectedURLParams(attrs map[string]string, isolationKey string, isolationValue string) url.Values { result := make(map[string][]string, len(dsnAttrOverrides)+len(attrs)+1) for k, v := range attrs { diff --git a/common/persistence/sql/sqlplugin/postgresql/session/session.go b/common/persistence/sql/sqlplugin/postgresql/session/session.go index e3713fa6887..b259adb3202 100644 --- a/common/persistence/sql/sqlplugin/postgresql/session/session.go +++ b/common/persistence/sql/sqlplugin/postgresql/session/session.go @@ -25,14 +25,17 @@ package session import ( + "context" "fmt" "net/url" "strings" + "time" "github.com/iancoleman/strcase" "github.com/jmoiron/sqlx" "go.temporal.io/server/common/config" + SQLAuth "go.temporal.io/server/common/persistence/sql/sqlplugin/auth" "go.temporal.io/server/common/resolver" ) @@ -62,6 +65,26 @@ func NewSession( cfg *config.SQL, resolver resolver.ServiceResolver, ) (*Session, error) { + if cfg.AuthPlugin != nil && cfg.AuthPlugin.Plugin != "" { + authPlugin, err := SQLAuth.LookupPlugin(cfg.AuthPlugin.Plugin) + if err != nil { + return nil, err + } + + timeout := cfg.AuthPlugin.Timeout + if timeout == 0 { + timeout = time.Duration(time.Second * 10) + } + + ctx, cancel := context.WithTimeout(context.TODO(), timeout) + defer cancel() + + cfg, err = authPlugin.GetConfig(ctx, cfg) + if err != nil { + return nil, err + } + } + db, err := createConnection(cfg, resolver) if err != nil { return nil, err diff --git a/docker/config_template.yaml b/docker/config_template.yaml index c3405f5848f..b214ffa74c0 100644 --- a/docker/config_template.yaml +++ b/docker/config_template.yaml @@ -73,6 +73,7 @@ persistence: maxConns: {{ default .Env.SQL_MAX_CONNS "20" }} maxIdleConns: {{ default .Env.SQL_MAX_IDLE_CONNS "20" }} maxConnLifetime: {{ default .Env.SQL_MAX_CONN_TIME "1h" }} + authPlugin: {{ default .Env.SQL_AUTH_PLUGIN "" }} tls: enabled: {{ default .Env.SQL_TLS_ENABLED "false" }} caFile: {{ default .Env.SQL_CA "" }} @@ -103,6 +104,7 @@ persistence: maxConns: {{ default .Env.SQL_VIS_MAX_CONNS "10" }} maxIdleConns: {{ default .Env.SQL_VIS_MAX_IDLE_CONNS "10" }} maxConnLifetime: {{ default .Env.SQL_VIS_MAX_CONN_TIME "1h" }} + authPlugin: {{ default .Env.SQL_VIS_AUTH_PLUGIN "" }} tls: enabled: {{ default .Env.SQL_TLS_ENABLED "false" }} caFile: {{ default .Env.SQL_CA "" }} @@ -122,6 +124,7 @@ persistence: maxConns: {{ default .Env.SQL_MAX_CONNS "20" }} maxIdleConns: {{ default .Env.SQL_MAX_IDLE_CONNS "20" }} maxConnLifetime: {{ default .Env.SQL_MAX_CONN_TIME "1h" }} + authPlugin: {{ default .Env.SQL_AUTH_PLUGIN "" }} tls: enabled: {{ default .Env.SQL_TLS_ENABLED "false" }} caFile: {{ default .Env.SQL_CA "" }} @@ -148,6 +151,7 @@ persistence: maxConns: {{ default .Env.SQL_VIS_MAX_CONNS "10" }} maxIdleConns: {{ default .Env.SQL_VIS_MAX_IDLE_CONNS "10" }} maxConnLifetime: {{ default .Env.SQL_VIS_MAX_CONN_TIME "1h" }} + authPlugin: {{ default .Env.SQL_VIS_AUTH_PLUGIN "" }} tls: enabled: {{ default .Env.SQL_TLS_ENABLED "false" }} caFile: {{ default .Env.SQL_CA "" }} diff --git a/go.mod b/go.mod index 2af32b58b52..c8e8db62fc8 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,9 @@ go 1.18 require ( cloud.google.com/go/storage v1.22.0 github.com/aws/aws-sdk-go v1.43.38 + github.com/aws/aws-sdk-go-v2 v1.16.3 + github.com/aws/aws-sdk-go-v2/config v1.15.5 + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.1.20 github.com/blang/semver/v4 v4.0.0 github.com/brianvoe/gofakeit/v6 v6.15.0 github.com/cactus/go-statsd-client/statsd v0.0.0-20200423205355-cb0885a1018c @@ -55,6 +58,18 @@ require ( modernc.org/sqlite v1.16.0 ) +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.12.0 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.3.11 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.11.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.16.4 // indirect + github.com/aws/smithy-go v1.11.2 // indirect +) + require ( cloud.google.com/go v0.100.2 // indirect cloud.google.com/go/compute v1.5.0 // indirect diff --git a/go.sum b/go.sum index 0b57947207c..f82e9f2a7a5 100644 --- a/go.sum +++ b/go.sum @@ -68,6 +68,30 @@ github.com/apache/thrift v0.0.0-20161221203622-b2a4d4ae21c7 h1:Fv9bK1Q+ly/ROk4aJ github.com/apache/thrift v0.0.0-20161221203622-b2a4d4ae21c7/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/aws/aws-sdk-go v1.43.38 h1:TDRjsUIsx2aeSuKkyzbwgltIRTbIKH6YCZbZ27JYhPk= github.com/aws/aws-sdk-go v1.43.38/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= +github.com/aws/aws-sdk-go-v2 v1.16.3 h1:0W1TSJ7O6OzwuEvIXAtJGvOeQ0SGAhcpxPN2/NK5EhM= +github.com/aws/aws-sdk-go-v2 v1.16.3/go.mod h1:ytwTPBG6fXTZLxxeeCCWj2/EMYp/xDUgX+OET6TLNNU= +github.com/aws/aws-sdk-go-v2/config v1.15.5 h1:P+xwhr6kabhxDTXTVH9YoHkqjLJ0wVVpIUHtFNr2hjU= +github.com/aws/aws-sdk-go-v2/config v1.15.5/go.mod h1:ZijHHh0xd/A+ZY53az0qzC5tT46kt4JVCePf2NX9Lk4= +github.com/aws/aws-sdk-go-v2/credentials v1.12.0 h1:4R/NqlcRFSkR0wxOhgHi+agGpbEr5qMCjn7VqUIJY+E= +github.com/aws/aws-sdk-go-v2/credentials v1.12.0/go.mod h1:9YWk7VW+eyKsoIL6/CljkTrNVWBSK9pkqOPUuijid4A= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4 h1:FP8gquGeGHHdfY6G5llaMQDF+HAf20VKc8opRwmjf04= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.4/go.mod h1:u/s5/Z+ohUQOPXl00m2yJVyioWDECsbpXTQlaqSlufc= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.1.20 h1:XJ7N5UHcBoEqJw8iqa0t9h7cUom2xu8EEHWodps8Z+Y= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.1.20/go.mod h1:/U6pWj+/0bIpmBzCpw66cnydBDjmuxdgGxyxdQGZIp4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.10 h1:uFWgo6mGJI1n17nbcvSc6fxVuR3xLNqvXt12JCnEcT8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.10/go.mod h1:F+EZtuIwjlv35kRJPyBGcsA4f7bnSoz15zOQ2lJq1Z4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.4 h1:cnsvEKSoHN4oAN7spMMr0zhEW2MHnhAVpmqQg8E6UcM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.4/go.mod h1:8glyUqVIM4AmeenIsPo0oVh3+NUwnsQml2OFupfQW+0= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.11 h1:6cZRymlLEIlDTEB0+5+An6Zj1CKt6rSE69tOmFeu1nk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.3.11/go.mod h1:0MR+sS1b/yxsfAPvAESrw8NfwUoxMinDyw6EYR9BS2U= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4 h1:b16QW0XWl0jWjLABFc1A+uh145Oqv+xDcObNk0iQgUk= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.4/go.mod h1:uKkN7qmSIsNJVyMtxNQoCEYMvFEXbOg9fwCJPdfp2u8= +github.com/aws/aws-sdk-go-v2/service/sso v1.11.4 h1:Uw5wBybFQ1UeA9ts0Y07gbv0ncZnIAyw858tDW0NP2o= +github.com/aws/aws-sdk-go-v2/service/sso v1.11.4/go.mod h1:cPDwJwsP4Kff9mldCXAmddjJL6JGQqtA3Mzer2zyr88= +github.com/aws/aws-sdk-go-v2/service/sts v1.16.4 h1:+xtV90n3abQmgzk1pS++FdxZTrPEDgQng6e4/56WR2A= +github.com/aws/aws-sdk-go-v2/service/sts v1.16.4/go.mod h1:lfSYenAXtavyX2A1LsViglqlG9eEFYxNryTZS5rn3QE= +github.com/aws/smithy-go v1.11.2 h1:eG/N+CcUMAvsdffgMvjMKwfyDzIkjM6pfxMJ8Mzc6mE= +github.com/aws/smithy-go v1.11.2/go.mod h1:3xHYmszWVx2c0kIwQeEVf9uSm4fYZt67FBJnwub1bgM= github.com/benbjohnson/clock v0.0.0-20160125162948-a620c1cc9866/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= diff --git a/tools/common/schema/types.go b/tools/common/schema/types.go index 49af50cf145..79dcc870eb1 100644 --- a/tools/common/schema/types.go +++ b/tools/common/schema/types.go @@ -89,6 +89,8 @@ const ( CLIOptDatabase = "database" // CLIOptPluginName is the cli option for plugin name CLIOptPluginName = "plugin" + // CLIOptAuthPluginName is the cli option for auth plugin name + CLIOptAuthPluginName = "sql-auth-plugin" // CLIOptConnectAttributes is the cli option for connect attributes (key/values via a url query string) CLIOptConnectAttributes = "connect-attributes" // CLIOptVersion is the cli option for version @@ -130,6 +132,8 @@ const ( CLIFlagDatabase = CLIOptDatabase + ", db" // CLIFlagPluginName is the cli flag for plugin name CLIFlagPluginName = CLIOptPluginName + ", pl" + // CLIFlagPluginName is the cli flag for sql auth plugin name + CLIFlagAuthPluginName = CLIOptAuthPluginName + ", ap" // CLIFlagConnectAttributes allows arbitrary connect attributes CLIFlagConnectAttributes = CLIOptConnectAttributes + ", ca" // CLIFlagVersion is the cli flag for version diff --git a/tools/sql/handler.go b/tools/sql/handler.go index bace5ab3bd5..a9b36f0931b 100644 --- a/tools/sql/handler.go +++ b/tools/sql/handler.go @@ -25,6 +25,7 @@ package sql import ( + "errors" "fmt" "net" "net/url" @@ -35,6 +36,7 @@ import ( "go.temporal.io/server/common/config" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" + SQLAuth "go.temporal.io/server/common/persistence/sql/sqlplugin/auth" "go.temporal.io/server/tools/common/schema" ) @@ -155,6 +157,9 @@ func parseConnectConfig(cli *cli.Context) (*config.SQL, error) { cfg.Password = cli.GlobalString(schema.CLIOptPassword) cfg.DatabaseName = cli.GlobalString(schema.CLIOptDatabase) cfg.PluginName = cli.GlobalString(schema.CLIOptPluginName) + cfg.AuthPlugin = &config.SQLAuthPlugin{ + Plugin: cli.GlobalString(schema.CLIOptAuthPluginName), + } if cfg.ConnectAttributes == nil { cfg.ConnectAttributes = map[string]string{} @@ -204,6 +209,12 @@ func ValidateConnectConfig(cfg *config.SQL) error { if cfg.DatabaseName == "" { return schema.NewConfigError("missing " + flag(schema.CLIOptDatabase) + " argument") } + if cfg.AuthPlugin != nil && cfg.AuthPlugin.Plugin != "" { + _, err := SQLAuth.LookupPlugin(cfg.AuthPlugin.Plugin) + if errors.Is(err, SQLAuth.ErrInvalidAuthPluginName) { + return schema.NewConfigError("invalid option for " + flag(schema.CLIOptAuthPluginName) + ": " + cfg.AuthPlugin.Plugin) + } + } return nil } diff --git a/tools/sql/main.go b/tools/sql/main.go index c4e5cc33274..ad15f4da949 100644 --- a/tools/sql/main.go +++ b/tools/sql/main.go @@ -98,6 +98,12 @@ func BuildCLIOptions() *cli.App { Usage: "name of the sql plugin", EnvVar: "SQL_PLUGIN", }, + cli.StringFlag{ + Name: schema.CLIFlagAuthPluginName, + Value: "", + Usage: "authentication plugin for sql database (supported: ['rds-iam-auth'])", + EnvVar: "SQL_AUTH_PLUGIN", + }, cli.BoolFlag{ Name: schema.CLIFlagQuiet, Usage: "Don't set exit status to 1 on error",