Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move usage flag configuration to agenthealth extension #1064

Merged
merged 6 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cfg/aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"

"github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/provider"
"github.com/aws/amazon-cloudwatch-agent/extension/agenthealth/handler/stats/agent"
)

const (
Expand Down Expand Up @@ -116,7 +116,7 @@ func getSession(config *aws.Config) *session.Session {
if len(found) > 0 {
log.Printf("W! Unused shared config file(s) found: %v. If you would like to use them, "+
"please update your common-config.toml.", found)
provider.GetFlagsStats().SetFlag(provider.FlagSharedConfigFallback)
agent.UsageFlags().Set(agent.FlagSharedConfigFallback)
}
}
return ses
Expand Down
2 changes: 2 additions & 0 deletions extension/agenthealth/handler/stats/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,6 @@ func NewOperationsFilter(operations ...string) OperationsFilter {
type StatsConfig struct {
// Operations are the allowed operation names to gather stats for.
Operations []string `mapstructure:"operations,omitempty"`
// UsageFlags are the usage flags to set on start up.
UsageFlags map[Flag]any `mapstructure:"usage_flags,omitempty"`
}
188 changes: 188 additions & 0 deletions extension/agenthealth/handler/stats/agent/flag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package agent

import (
"encoding"
"errors"
"fmt"
"sync"

"github.com/aws/aws-sdk-go/aws"
)

var (
errUnsupportedFlag = errors.New("unsupported usage flag")
)

const (
FlagIMDSFallbackSuccess Flag = iota
FlagSharedConfigFallback
FlagAppSignal
FlagEnhancedContainerInsights
FlagRunningInContainer
FlagMode
FlagRegionType

flagIMDSFallbackSuccessStr = "imds_fallback_success"
flagSharedConfigFallbackStr = "shared_config_fallback"
flagAppSignalsStr = "app_signals"
flagEnhancedContainerInsightsStr = "enhanced_container_insights"
flagRunningInContainerStr = "running_in_container"
flagModeStr = "mode"
flagRegionTypeStr = "region_type"
jefchien marked this conversation as resolved.
Show resolved Hide resolved
)

type Flag int

var _ encoding.TextMarshaler = (*Flag)(nil)
var _ encoding.TextUnmarshaler = (*Flag)(nil)

func (f Flag) String() string {
switch f {
case FlagAppSignal:
return flagAppSignalsStr
case FlagEnhancedContainerInsights:
return flagEnhancedContainerInsightsStr
case FlagIMDSFallbackSuccess:
return flagIMDSFallbackSuccessStr
case FlagMode:
return flagModeStr
case FlagRegionType:
return flagRegionTypeStr
case FlagRunningInContainer:
return flagRunningInContainerStr
case FlagSharedConfigFallback:
return flagSharedConfigFallbackStr
}
return ""
}

func (f Flag) MarshalText() (text []byte, err error) {
s := f.String()
if s == "" {
return nil, fmt.Errorf("%w: %[2]T(%[2]d)", errUnsupportedFlag, f)
}
return []byte(s), nil
}

func (f *Flag) UnmarshalText(text []byte) error {
switch s := string(text); s {
case flagAppSignalsStr:
*f = FlagAppSignal
case flagEnhancedContainerInsightsStr:
*f = FlagEnhancedContainerInsights
case flagIMDSFallbackSuccessStr:
*f = FlagIMDSFallbackSuccess
case flagModeStr:
*f = FlagMode
case flagRegionTypeStr:
*f = FlagRegionType
case flagRunningInContainerStr:
*f = FlagRunningInContainer
case flagSharedConfigFallbackStr:
*f = FlagSharedConfigFallback
default:
return fmt.Errorf("%w: %s", errUnsupportedFlag, s)
}
return nil
}

var (
flagSingleton FlagSet
flagOnce sync.Once
)

// FlagSet is a getter/setter for flag/value pairs. Once a flag key is set, its value is immutable.
type FlagSet interface {
// IsSet returns if the flag is present in the backing map.
IsSet(flag Flag) bool
// GetString if the value stored with the flag is a string. If not, returns nil.
GetString(flag Flag) *string
// Set adds the Flag with an unused value.
Set(flag Flag)
// SetValue adds the Flag with a value.
SetValue(flag Flag, value any)
// SetValues adds each Flag/value pair.
SetValues(flags map[Flag]any)
// OnChange registers a callback that triggers on flag sets.
OnChange(callback func())
}

type flagSet struct {
m sync.Map
mu sync.RWMutex
callbacks []func()
}

var _ FlagSet = (*flagSet)(nil)

func (p *flagSet) IsSet(flag Flag) bool {
_, ok := p.m.Load(flag)
return ok
}

func (p *flagSet) GetString(flag Flag) *string {
value, ok := p.m.Load(flag)
if !ok {
return nil
}
var str string
str, ok = value.(string)
if !ok || str == "" {
return nil
}
return aws.String(str)
}

func (p *flagSet) Set(flag Flag) {
p.SetValue(flag, 1)
}

func (p *flagSet) SetValue(flag Flag, value any) {
if p.setWithValue(flag, value) {
p.notify()
}
}

func (p *flagSet) SetValues(m map[Flag]any) {
var changed bool
for flag, value := range m {
if p.setWithValue(flag, value) {
changed = true
}
}
if changed {
p.notify()
}
}

func (p *flagSet) setWithValue(flag Flag, value any) bool {
if !p.IsSet(flag) {
p.m.Store(flag, value)
return true
}
return false
}

func (p *flagSet) OnChange(f func()) {
p.mu.Lock()
defer p.mu.Unlock()
p.callbacks = append(p.callbacks, f)
}

func (p *flagSet) notify() {
p.mu.RLock()
defer p.mu.RUnlock()
for _, callback := range p.callbacks {
callback()
}
}

func UsageFlags() FlagSet {
flagOnce.Do(func() {
flagSingleton = &flagSet{}
})
return flagSingleton
}
90 changes: 90 additions & 0 deletions extension/agenthealth/handler/stats/agent/flag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT

package agent

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestFlagSet(t *testing.T) {
fs := &flagSet{}
var notifyCount int
fs.OnChange(func() {
notifyCount++
})
assert.False(t, fs.IsSet(FlagIMDSFallbackSuccess))
assert.Nil(t, fs.GetString(FlagIMDSFallbackSuccess))
fs.Set(FlagIMDSFallbackSuccess)
assert.True(t, fs.IsSet(FlagIMDSFallbackSuccess))
assert.Nil(t, fs.GetString(FlagIMDSFallbackSuccess))
assert.Equal(t, 1, notifyCount)
// already set, so ignored
fs.SetValue(FlagIMDSFallbackSuccess, "ignores this")
assert.Nil(t, fs.GetString(FlagIMDSFallbackSuccess))
assert.Equal(t, 1, notifyCount)
fs.SetValues(map[Flag]any{
FlagMode: "test/mode",
FlagRegionType: "test/region-type",
})
assert.True(t, fs.IsSet(FlagMode))
assert.True(t, fs.IsSet(FlagRegionType))
got := fs.GetString(FlagMode)
assert.NotNil(t, got)
assert.Equal(t, "test/mode", *got)
got = fs.GetString(FlagRegionType)
assert.NotNil(t, got)
assert.Equal(t, "test/region-type", *got)
assert.Equal(t, 2, notifyCount)
fs.SetValues(map[Flag]any{
FlagRegionType: "other",
})
assert.NotNil(t, got)
assert.Equal(t, "test/region-type", *got)
assert.Equal(t, 2, notifyCount)
fs.SetValues(map[Flag]any{
FlagMode: "other/mode",
FlagRunningInContainer: true,
})
got = fs.GetString(FlagMode)
assert.NotNil(t, got)
assert.Equal(t, "test/mode", *got)
assert.True(t, fs.IsSet(FlagRunningInContainer))
assert.Equal(t, 3, notifyCount)
}

func TestFlag(t *testing.T) {
testCases := []struct {
flag Flag
str string
}{
{flag: FlagAppSignal, str: flagAppSignalsStr},
{flag: FlagEnhancedContainerInsights, str: flagEnhancedContainerInsightsStr},
{flag: FlagIMDSFallbackSuccess, str: flagIMDSFallbackSuccessStr},
{flag: FlagMode, str: flagModeStr},
{flag: FlagRegionType, str: flagRegionTypeStr},
{flag: FlagRunningInContainer, str: flagRunningInContainerStr},
{flag: FlagSharedConfigFallback, str: flagSharedConfigFallbackStr},
}
for _, testCase := range testCases {
flag := testCase.flag
got, err := flag.MarshalText()
assert.NoError(t, err)
assert.EqualValues(t, testCase.str, got)
assert.NoError(t, flag.UnmarshalText(got))
assert.Equal(t, flag, testCase.flag)
}
}

func TestInvalidFlag(t *testing.T) {
f := Flag(-1)
got, err := f.MarshalText()
assert.Error(t, err)
assert.ErrorIs(t, err, errUnsupportedFlag)
assert.Nil(t, got)
err = f.UnmarshalText([]byte("Flag(-1)"))
assert.Error(t, err)
assert.ErrorIs(t, err, errUnsupportedFlag)
}
1 change: 1 addition & 0 deletions extension/agenthealth/handler/stats/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func NewHandlers(logger *zap.Logger, cfg agent.StatsConfig) ([]awsmiddleware.Req
filter := agent.NewOperationsFilter(cfg.Operations...)
clientStats := client.NewHandler(filter)
stats := newStatsHandler(logger, filter, []agent.StatsProvider{clientStats, provider.GetProcessStats(), provider.GetFlagsStats()})
agent.UsageFlags().SetValues(cfg.UsageFlags)
return []awsmiddleware.RequestHandler{stats, clientStats}, []awsmiddleware.ResponseHandler{clientStats}
}

Expand Down
Loading
Loading