From b82cfc53908c1ba00dc65cda4231bb9590e066e1 Mon Sep 17 00:00:00 2001 From: Michel Laterman <82832767+michel-laterman@users.noreply.github.com> Date: Thu, 11 Apr 2024 09:02:06 -0700 Subject: [PATCH] Add support for repacking and merging tls.Config structs (#196) Add support for the Unpack operation for tls.Config attributes that are specified as strings but stored as other values. The existing Unpack operations are limited to string only and these will cause issues when trying to merge previously unpacked structs. --- transport/tlscommon/types.go | 92 ++++++++++++++++++++----------- transport/tlscommon/types_test.go | 31 +++++++++++ transport/tlscommon/versions.go | 23 ++++++-- 3 files changed, 110 insertions(+), 36 deletions(-) diff --git a/transport/tlscommon/types.go b/transport/tlscommon/types.go index 2dbf5359..5d178302 100644 --- a/transport/tlscommon/types.go +++ b/transport/tlscommon/types.go @@ -159,28 +159,37 @@ func (m TLSVerificationMode) MarshalText() ([]byte, error) { return nil, fmt.Errorf("could not marshal '%+v' to text", m) } -// Unpack unpacks the string into constants. +// Unpack unpacks the input into a TLSVerificationMode. func (m *TLSVerificationMode) Unpack(in interface{}) error { if in == nil { *m = VerifyFull return nil } - s, ok := in.(string) - if !ok { - return fmt.Errorf("verification mode must be an identifier") - } - if s == "" { - *m = VerifyFull - return nil + switch o := in.(type) { + case string: + if o == "" { + *m = VerifyFull + return nil + } + + mode, found := tlsVerificationModes[o] + if !found { + return fmt.Errorf("unknown verification mode '%v'", o) + } + *m = mode + case uint64: + *m = TLSVerificationMode(o) + default: + return fmt.Errorf("verification mode is an unknown type: %T", o) } + return nil +} - mode, found := tlsVerificationModes[s] - if !found { - return fmt.Errorf("unknown verification mode '%v'", s) +func (m *TLSVerificationMode) Validate() error { + if *m > VerifyStrict { + return fmt.Errorf("unsupported verification mode: %v", m) } - - *m = mode return nil } @@ -214,13 +223,20 @@ func (m *TLSClientAuth) Unpack(s string) error { type CipherSuite uint16 -func (cs *CipherSuite) Unpack(s string) error { - suite, found := tlsCipherSuites[s] - if !found { - return fmt.Errorf("invalid tls cipher suite '%v'", s) +func (cs *CipherSuite) Unpack(i interface{}) error { + switch o := i.(type) { + case string: + suite, found := tlsCipherSuites[o] + if !found { + return fmt.Errorf("invalid tls cipher suite '%v'", o) + } + + *cs = suite + case uint64: + *cs = CipherSuite(o) + default: + return fmt.Errorf("cipher suite is an unknown type: %T", o) } - - *cs = suite return nil } @@ -233,13 +249,20 @@ func (cs CipherSuite) String() string { type tlsCurveType tls.CurveID -func (ct *tlsCurveType) Unpack(s string) error { - t, found := tlsCurveTypes[s] - if !found { - return fmt.Errorf("invalid tls curve type '%v'", s) +func (ct *tlsCurveType) Unpack(i interface{}) error { + switch o := i.(type) { + case string: + t, found := tlsCurveTypes[o] + if !found { + return fmt.Errorf("invalid tls curve type '%v'", o) + } + + *ct = t + case uint64: + *ct = tlsCurveType(o) + default: + return fmt.Errorf("tls curve type is an unsupported input type: %T", o) } - - *ct = t return nil } @@ -252,13 +275,20 @@ func (r TLSRenegotiationSupport) String() string { return "<" + unknownType + ">" } -func (r *TLSRenegotiationSupport) Unpack(s string) error { - t, found := tlsRenegotiationSupportTypes[s] - if !found { - return fmt.Errorf("invalid tls renegotiation type '%v'", s) +func (r *TLSRenegotiationSupport) Unpack(i interface{}) error { + switch o := i.(type) { + case string: + t, found := tlsRenegotiationSupportTypes[o] + if !found { + return fmt.Errorf("invalid tls renegotiation type '%v'", o) + } + + *r = t + case uint64: + *r = TLSRenegotiationSupport(o) + default: + return fmt.Errorf("tls renegotation support is an unknown type: %T", o) } - - *r = t return nil } diff --git a/transport/tlscommon/types_test.go b/transport/tlscommon/types_test.go index 7a58c5e9..fb88d694 100644 --- a/transport/tlscommon/types_test.go +++ b/transport/tlscommon/types_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/elastic/elastic-agent-libs/config" + "github.com/elastic/go-ucfg" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -69,6 +70,36 @@ func TestLoadWithEmptyVerificationMode(t *testing.T) { assert.Equal(t, cfg.VerificationMode, VerifyFull) } +func TestRepackConfig(t *testing.T) { + cfg, err := load(` + enabled: true + verification_mode: certificate + supported_protocols: [TLSv1.1, TLSv1.2] + cipher_suites: + - RSA-AES-256-CBC-SHA + certificate_authorities: + - /path/to/ca.crt + certificate: /path/to/cert.crt + key: /path/to/key.crt + curve_types: + - P-521 + renegotiation: freely + ca_sha256: + - example + ca_trusted_fingerprint: fingerprint + `) + + assert.NoError(t, err) + assert.Equal(t, cfg.VerificationMode, VerifyCertificate) + + tmp, err := ucfg.NewFrom(cfg) + assert.NoError(t, err) + + err = tmp.Unpack(cfg) + assert.NoError(t, err) + assert.Equal(t, cfg.VerificationMode, VerifyCertificate) +} + func TestTLSClientAuthUnpack(t *testing.T) { tests := []struct { val string diff --git a/transport/tlscommon/versions.go b/transport/tlscommon/versions.go index e630ef4d..3e2ad498 100644 --- a/transport/tlscommon/versions.go +++ b/transport/tlscommon/versions.go @@ -38,12 +38,25 @@ func (v TLSVersion) Details() *TLSVersionDetails { } // Unpack transforms the string into a constant. -func (v *TLSVersion) Unpack(s string) error { - version, found := tlsProtocolVersions[s] - if !found { - return fmt.Errorf("invalid tls version '%v'", s) +func (v *TLSVersion) Unpack(i interface{}) error { + switch o := i.(type) { + case string: + version, found := tlsProtocolVersions[o] + if !found { + return fmt.Errorf("invalid tls version '%v'", o) + } + *v = version + case uint64: + *v = TLSVersion(o) + default: + return fmt.Errorf("tls version is an unknown type: %T", o) } + return nil +} - *v = version +func (v *TLSVersion) Validate() error { + if *v < TLSVersionMin || *v > TLSVersionMax { + return fmt.Errorf("unsupported tls version: %v", v) + } return nil }