Skip to content

Commit

Permalink
Add AWS region to the AWS Config Cache key (#6134)
Browse files Browse the repository at this point in the history
* Introduce aws region into the AWS config cache

Signed-off-by: Maksymilian Boguń <[email protected]>

* add CHANGELOG entry

Signed-off-by: Maksymilian Boguń <[email protected]>

* embedded AWS region into Authorization metadata

Signed-off-by: Maksymilian Boguń <[email protected]>

* move the fix to Unreleased version

Signed-off-by: Maksymilian Boguń <[email protected]>

* Fix indentation

Signed-off-by: Maksymilian Boguń <[email protected]>

---------

Signed-off-by: Maksymilian Boguń <[email protected]>
Signed-off-by: Jan Wozniak <[email protected]>
Co-authored-by: Jorge Turrado Ferrero <[email protected]>
Co-authored-by: Jan Wozniak <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2024
1 parent c43af59 commit 1eaa34c
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 77 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Here is an overview of all new **experimental** features:
- **General**: Centralize and improve automaxprocs configuration with proper structured logging ([#5970](https://github.com/kedacore/keda/issues/5970))
- **General**: Paused ScaledObject count is reported correctly after operator restart ([#6321](https://github.com/kedacore/keda/issues/6321))
- **General**: ScaledJobs ready status set to true when recoverred problem ([#6329](https://github.com/kedacore/keda/pull/6329))
- **AWS Scalers**: Add AWS region to the AWS Config Cache key ([#6128](https://github.com/kedacore/keda/issues/6128))
- **Selenium Grid Scaler**: Exposes sum of pending and ongoing sessions to KDEA ([#6368](https://github.com/kedacore/keda/pull/6368))

### Deprecations
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/apache_kafka_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func getApacheKafkaClient(ctx context.Context, metadata apacheKafkaMetadata, log
case KafkaSASLTypeOAuthbearer:
return nil, errors.New("SASL/OAUTHBEARER is not implemented yet")
case KafkaSASLTypeMskIam:
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AWSRegion, metadata.AWSAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AWSAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/scalers/aws/aws_authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ type AuthorizationMetadata struct {
AwsSecretAccessKey string
AwsSessionToken string

AwsRegion string

// Deprecated
PodIdentityOwner bool
// Pod identity owner is confusing and it'll be removed when we get
Expand Down
31 changes: 13 additions & 18 deletions pkg/scalers/aws/aws_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,43 +39,33 @@ import (
// ErrAwsNoAccessKey is returned when awsAccessKeyID is missing.
var ErrAwsNoAccessKey = errors.New("awsAccessKeyID not found")

type awsConfigMetadata struct {
awsRegion string
awsAuthorization AuthorizationMetadata
}

var awsSharedCredentialsCache = newSharedConfigsCache()

// GetAwsConfig returns an *aws.Config for a given AuthorizationMetadata
// If AuthorizationMetadata uses static credentials or `aws` auth,
// we recover the *aws.Config from the shared cache. If not, we generate
// a new entry on each request
func GetAwsConfig(ctx context.Context, awsRegion string, awsAuthorization AuthorizationMetadata) (*aws.Config, error) {
metadata := &awsConfigMetadata{
awsRegion: awsRegion,
awsAuthorization: awsAuthorization,
}

if metadata.awsAuthorization.UsingPodIdentity ||
(metadata.awsAuthorization.AwsAccessKeyID != "" && metadata.awsAuthorization.AwsSecretAccessKey != "") {
return awsSharedCredentialsCache.GetCredentials(ctx, metadata.awsRegion, metadata.awsAuthorization)
func GetAwsConfig(ctx context.Context, awsAuthorization AuthorizationMetadata) (*aws.Config, error) {
if awsAuthorization.UsingPodIdentity ||
(awsAuthorization.AwsAccessKeyID != "" && awsAuthorization.AwsSecretAccessKey != "") {
return awsSharedCredentialsCache.GetCredentials(ctx, awsAuthorization)
}

// TODO, remove when aws-eks are removed
configOptions := make([]func(*config.LoadOptions) error, 0)
configOptions = append(configOptions, config.WithRegion(metadata.awsRegion))
configOptions = append(configOptions, config.WithRegion(awsAuthorization.AwsRegion))
cfg, err := config.LoadDefaultConfig(ctx, configOptions...)
if err != nil {
return nil, err
}

if !metadata.awsAuthorization.PodIdentityOwner {
if !awsAuthorization.PodIdentityOwner {
return &cfg, nil
}

if metadata.awsAuthorization.AwsRoleArn != "" {
if awsAuthorization.AwsRoleArn != "" {
stsSvc := sts.NewFromConfig(cfg)
stsCredentialProvider := stscreds.NewAssumeRoleProvider(stsSvc, metadata.awsAuthorization.AwsRoleArn, func(_ *stscreds.AssumeRoleOptions) {})
stsCredentialProvider := stscreds.NewAssumeRoleProvider(stsSvc, awsAuthorization.AwsRoleArn, func(_ *stscreds.AssumeRoleOptions) {})
cfg.Credentials = aws.NewCredentialsCache(stsCredentialProvider)
}
return &cfg, err
Expand All @@ -88,13 +78,18 @@ func GetAwsAuthorization(uniqueKey string, podIdentity kedav1alpha1.AuthPodIdent
TriggerUniqueKey: uniqueKey,
}

if val, ok := authParams["awsRegion"]; ok && val != "" {
meta.AwsRegion = val
}

if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAws {
meta.UsingPodIdentity = true
if val, ok := authParams["awsRoleArn"]; ok && val != "" {
meta.AwsRoleArn = val
}
return meta, nil
}

// TODO, remove all the logic below and just keep the logic for
// parsing awsAccessKeyID, awsSecretAccessKey and awsSessionToken
// when aws-eks are removed
Expand Down
10 changes: 5 additions & 5 deletions pkg/scalers/aws/aws_config_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ func newSharedConfigsCache() sharedConfigCache {
// getCacheKey returns a unique key based on given AuthorizationMetadata.
// As it can contain sensitive data, the key is hashed to not expose secrets
func (a *sharedConfigCache) getCacheKey(awsAuthorization AuthorizationMetadata) string {
key := "keda"
key := "keda-" + awsAuthorization.AwsRegion
if awsAuthorization.AwsAccessKeyID != "" {
key = fmt.Sprintf("%s-%s-%s", awsAuthorization.AwsAccessKeyID, awsAuthorization.AwsSecretAccessKey, awsAuthorization.AwsSessionToken)
key = fmt.Sprintf("%s-%s-%s-%s", awsAuthorization.AwsAccessKeyID, awsAuthorization.AwsSecretAccessKey, awsAuthorization.AwsSessionToken, awsAuthorization.AwsRegion)
} else if awsAuthorization.AwsRoleArn != "" {
key = awsAuthorization.AwsRoleArn
key = fmt.Sprintf("%s-%s", awsAuthorization.AwsRoleArn, awsAuthorization.AwsRegion)
}
// to avoid sensitive data as key and to use a constant key size,
// we hash the key with sha3
Expand All @@ -86,7 +86,7 @@ func (a *sharedConfigCache) getCacheKey(awsAuthorization AuthorizationMetadata)
// sharing it between all the requests. To track if the *aws.Config is used by whom,
// every time when an scaler requests *aws.Config we register it inside
// the cached item.
func (a *sharedConfigCache) GetCredentials(ctx context.Context, awsRegion string, awsAuthorization AuthorizationMetadata) (*aws.Config, error) {
func (a *sharedConfigCache) GetCredentials(ctx context.Context, awsAuthorization AuthorizationMetadata) (*aws.Config, error) {
a.Lock()
defer a.Unlock()
key := a.getCacheKey(awsAuthorization)
Expand All @@ -97,7 +97,7 @@ func (a *sharedConfigCache) GetCredentials(ctx context.Context, awsRegion string
}

configOptions := make([]func(*config.LoadOptions) error, 0)
configOptions = append(configOptions, config.WithRegion(awsRegion))
configOptions = append(configOptions, config.WithRegion(awsAuthorization.AwsRegion))
cfg, err := config.LoadDefaultConfig(ctx, configOptions...)
if err != nil {
return nil, err
Expand Down
77 changes: 44 additions & 33 deletions pkg/scalers/aws/aws_config_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,84 +28,95 @@ import (
func TestGetCredentialsReturnNewItemAndStoreItIfNotExist(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config := awsConfigMetadata{
awsRegion: "test-region",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test-key",
},
awsAuthorization := AuthorizationMetadata{
TriggerUniqueKey: "test-key",
AwsRegion: "test-region",
}
cacheKey := cache.getCacheKey(config.awsAuthorization)
_, err := cache.GetCredentials(context.Background(), config.awsRegion, config.awsAuthorization)
cacheKey := cache.getCacheKey(awsAuthorization)
_, err := cache.GetCredentials(context.Background(), awsAuthorization)
assert.NoError(t, err)
assert.Contains(t, cache.items, cacheKey)
assert.Contains(t, cache.items[cacheKey].usages, config.awsAuthorization.TriggerUniqueKey)
assert.Contains(t, cache.items[cacheKey].usages, awsAuthorization.TriggerUniqueKey)
}

func TestGetCredentialsReturnCachedItemIfExist(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config := awsConfigMetadata{
awsRegion: "test1-region",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test1-key",
},
awsAuthorization := AuthorizationMetadata{
TriggerUniqueKey: "test1-key",
AwsRegion: "test1-region",
}
cfg := aws.Config{}
cfg.AppID = "test1-app"
cacheKey := cache.getCacheKey(config.awsAuthorization)
cacheKey := cache.getCacheKey(awsAuthorization)
cache.items[cacheKey] = cacheEntry{
config: &cfg,
usages: map[string]bool{
"other-usage": true,
},
}
configFromCache, err := cache.GetCredentials(context.Background(), config.awsRegion, config.awsAuthorization)
configFromCache, err := cache.GetCredentials(context.Background(), awsAuthorization)
assert.NoError(t, err)
assert.Equal(t, &cfg, configFromCache)
assert.Contains(t, cache.items[cacheKey].usages, config.awsAuthorization.TriggerUniqueKey)
assert.Contains(t, cache.items[cacheKey].usages, awsAuthorization.TriggerUniqueKey)
}

func TestRemoveCachedEntryRemovesCachedItemIfNotUsages(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config := awsConfigMetadata{
awsRegion: "test2-region",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test2-key",
},
awsAuthorization := AuthorizationMetadata{
TriggerUniqueKey: "test2-key",
AwsRegion: "test2-region",
}
cfg := aws.Config{}
cfg.AppID = "test2-app"
cacheKey := cache.getCacheKey(config.awsAuthorization)
cacheKey := cache.getCacheKey(awsAuthorization)
cache.items[cacheKey] = cacheEntry{
config: &cfg,
usages: map[string]bool{
config.awsAuthorization.TriggerUniqueKey: true,
awsAuthorization.TriggerUniqueKey: true,
},
}
cache.RemoveCachedEntry(config.awsAuthorization)
cache.RemoveCachedEntry(awsAuthorization)
assert.NotContains(t, cache.items, cacheKey)
}

func TestRemoveCachedEntryNotRemoveCachedItemIfUsages(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
config := awsConfigMetadata{
awsRegion: "test3-region",
awsAuthorization: AuthorizationMetadata{
TriggerUniqueKey: "test3-key",
},
awsAuthorization := AuthorizationMetadata{
TriggerUniqueKey: "test3-key",
AwsRegion: "test3-region",
}
cfg := aws.Config{}
cfg.AppID = "test3-app"
cacheKey := cache.getCacheKey(config.awsAuthorization)
cacheKey := cache.getCacheKey(awsAuthorization)
cache.items[cacheKey] = cacheEntry{
config: &cfg,
usages: map[string]bool{
config.awsAuthorization.TriggerUniqueKey: true,
"other-usage": true,
awsAuthorization.TriggerUniqueKey: true,
"other-usage": true,
},
}
cache.RemoveCachedEntry(config.awsAuthorization)
cache.RemoveCachedEntry(awsAuthorization)
assert.Contains(t, cache.items, cacheKey)
}

func TestCredentialsShouldBeCachedPerRegion(t *testing.T) {
cache := newSharedConfigsCache()
cache.logger = logr.Discard()
awsAuthorization1 := AuthorizationMetadata{
TriggerUniqueKey: "test4-key",
AwsRegion: "test4-region1",
}
awsAuthorization2 := AuthorizationMetadata{
TriggerUniqueKey: "test4-key",
AwsRegion: "test4-region2",
}
cred1, err1 := cache.GetCredentials(context.Background(), awsAuthorization1)
cred2, err2 := cache.GetCredentials(context.Background(), awsAuthorization2)

assert.NoError(t, err1)
assert.NoError(t, err2)
assert.NotEqual(t, cred1, cred2, "Credentials should be stored per region")
}
18 changes: 5 additions & 13 deletions pkg/scalers/aws/aws_sigv4.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,12 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
}

// parseAwsAMPMetadata parses the data to get the AWS sepcific auth info and metadata
func parseAwsAMPMetadata(config *scalersconfig.ScalerConfig) (*awsConfigMetadata, error) {
meta := awsConfigMetadata{}

if val, ok := config.TriggerMetadata["awsRegion"]; ok && val != "" {
meta.awsRegion = val
}

func parseAwsAMPMetadata(config *scalersconfig.ScalerConfig) (*AuthorizationMetadata, error) {
auth, err := GetAwsAuthorization(config.TriggerUniqueKey, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
if err != nil {
return nil, err
}

meta.awsAuthorization = auth
return &meta, nil
return &auth, nil
}

// NewSigV4RoundTripper returns a new http.RoundTripper that will sign requests
Expand All @@ -100,11 +92,11 @@ func NewSigV4RoundTripper(config *scalersconfig.ScalerConfig) (http.RoundTripper
// which is probably the reason to create a SigV4RoundTripper.
// To prevent failures we check if the metadata is nil
// (missing AWS info) and we hide the error
metadata, _ := parseAwsAMPMetadata(config)
if metadata == nil {
awsAuthorization, _ := parseAwsAMPMetadata(config)
if awsAuthorization == nil {
return nil, nil
}
awsCfg, err := GetAwsConfig(context.Background(), metadata.awsRegion, metadata.awsAuthorization)
awsCfg, err := GetAwsConfig(context.Background(), *awsAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_cloudwatch_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func NewAwsCloudwatchScaler(ctx context.Context, config *scalersconfig.ScalerCon
}

func createCloudwatchClient(ctx context.Context, metadata *awsCloudwatchMetadata) (*cloudwatch.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)

if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_dynamodb_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func parseAwsDynamoDBMetadata(config *scalersconfig.ScalerConfig) (*awsDynamoDBM
}

func createDynamoDBClient(ctx context.Context, metadata *awsDynamoDBMetadata) (*dynamodb.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_dynamodb_streams_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig) (*awsDy
}

func createClientsForDynamoDBStreamsScaler(ctx context.Context, metadata *awsDynamoDBStreamsMetadata) (*dynamodb.Client, *dynamodbstreams.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_kinesis_stream_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func parseAwsKinesisStreamMetadata(config *scalersconfig.ScalerConfig, logger lo
}

func createKinesisClient(ctx context.Context, metadata *awsKinesisStreamMetadata) (*kinesis.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/aws_sqs_queue_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func parseAwsSqsQueueMetadata(config *scalersconfig.ScalerConfig) (*awsSqsQueueM
}

func createSqsClient(ctx context.Context, metadata *awsSqsQueueMetadata) (*sqs.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/kafka_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ func getKafkaClientConfig(ctx context.Context, metadata kafkaMetadata) (*sarama.
case KafkaSASLOAuthTokenProviderBearer:
config.Net.SASL.TokenProvider = kafka.OAuthBearerTokenProvider(metadata.username, metadata.password, metadata.oauthTokenEndpointURI, metadata.scopes, metadata.oauthExtensions)
case KafkaSASLOAuthTokenProviderAWSMSKIAM:
awsAuth, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
awsAuth, err := awsutils.GetAwsConfig(ctx, metadata.awsAuthorization)
if err != nil {
return nil, fmt.Errorf("error getting AWS config: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/scaling/resolver/aws_secretmanager_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (ash *AwsSecretManagerHandler) Initialize(ctx context.Context, client clien
if ash.secretManager.Region != "" {
awsRegion = ash.secretManager.Region
}
ash.awsMetadata.AwsRegion = awsRegion
podIdentity := ash.secretManager.PodIdentity
if podIdentity == nil {
podIdentity = &kedav1alpha1.AuthPodIdentity{}
Expand Down Expand Up @@ -100,7 +101,7 @@ func (ash *AwsSecretManagerHandler) Initialize(ctx context.Context, client clien
return fmt.Errorf("pod identity provider %s not supported", podIdentity.Provider)
}

config, err := awsutils.GetAwsConfig(ctx, awsRegion, ash.awsMetadata)
config, err := awsutils.GetAwsConfig(ctx, ash.awsMetadata)
if err != nil {
logger.Error(err, "Error getting credentials")
return err
Expand Down

0 comments on commit 1eaa34c

Please sign in to comment.