Skip to content

Commit

Permalink
Add support for repacking and merging tls.Config structs (#196)
Browse files Browse the repository at this point in the history
Add support for the Unpack operation for tls.Config attributes that are
specified as strings but stored as other values. The existing Unpack
operations are limited to string only and these will cause issues when
trying to merge previously unpacked structs.
  • Loading branch information
michel-laterman authored Apr 11, 2024
1 parent ccbaaef commit b82cfc5
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 36 deletions.
92 changes: 61 additions & 31 deletions transport/tlscommon/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,28 +159,37 @@ func (m TLSVerificationMode) MarshalText() ([]byte, error) {
return nil, fmt.Errorf("could not marshal '%+v' to text", m)
}

// Unpack unpacks the string into constants.
// Unpack unpacks the input into a TLSVerificationMode.
func (m *TLSVerificationMode) Unpack(in interface{}) error {
if in == nil {
*m = VerifyFull
return nil
}

s, ok := in.(string)
if !ok {
return fmt.Errorf("verification mode must be an identifier")
}
if s == "" {
*m = VerifyFull
return nil
switch o := in.(type) {
case string:
if o == "" {
*m = VerifyFull
return nil
}

mode, found := tlsVerificationModes[o]
if !found {
return fmt.Errorf("unknown verification mode '%v'", o)
}
*m = mode
case uint64:
*m = TLSVerificationMode(o)
default:
return fmt.Errorf("verification mode is an unknown type: %T", o)
}
return nil
}

mode, found := tlsVerificationModes[s]
if !found {
return fmt.Errorf("unknown verification mode '%v'", s)
func (m *TLSVerificationMode) Validate() error {
if *m > VerifyStrict {
return fmt.Errorf("unsupported verification mode: %v", m)
}

*m = mode
return nil
}

Expand Down Expand Up @@ -214,13 +223,20 @@ func (m *TLSClientAuth) Unpack(s string) error {

type CipherSuite uint16

func (cs *CipherSuite) Unpack(s string) error {
suite, found := tlsCipherSuites[s]
if !found {
return fmt.Errorf("invalid tls cipher suite '%v'", s)
func (cs *CipherSuite) Unpack(i interface{}) error {
switch o := i.(type) {
case string:
suite, found := tlsCipherSuites[o]
if !found {
return fmt.Errorf("invalid tls cipher suite '%v'", o)
}

*cs = suite
case uint64:
*cs = CipherSuite(o)
default:
return fmt.Errorf("cipher suite is an unknown type: %T", o)
}

*cs = suite
return nil
}

Expand All @@ -233,13 +249,20 @@ func (cs CipherSuite) String() string {

type tlsCurveType tls.CurveID

func (ct *tlsCurveType) Unpack(s string) error {
t, found := tlsCurveTypes[s]
if !found {
return fmt.Errorf("invalid tls curve type '%v'", s)
func (ct *tlsCurveType) Unpack(i interface{}) error {
switch o := i.(type) {
case string:
t, found := tlsCurveTypes[o]
if !found {
return fmt.Errorf("invalid tls curve type '%v'", o)
}

*ct = t
case uint64:
*ct = tlsCurveType(o)
default:
return fmt.Errorf("tls curve type is an unsupported input type: %T", o)
}

*ct = t
return nil
}

Expand All @@ -252,13 +275,20 @@ func (r TLSRenegotiationSupport) String() string {
return "<" + unknownType + ">"
}

func (r *TLSRenegotiationSupport) Unpack(s string) error {
t, found := tlsRenegotiationSupportTypes[s]
if !found {
return fmt.Errorf("invalid tls renegotiation type '%v'", s)
func (r *TLSRenegotiationSupport) Unpack(i interface{}) error {
switch o := i.(type) {
case string:
t, found := tlsRenegotiationSupportTypes[o]
if !found {
return fmt.Errorf("invalid tls renegotiation type '%v'", o)
}

*r = t
case uint64:
*r = TLSRenegotiationSupport(o)
default:
return fmt.Errorf("tls renegotation support is an unknown type: %T", o)
}

*r = t
return nil
}

Expand Down
31 changes: 31 additions & 0 deletions transport/tlscommon/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"

"github.com/elastic/elastic-agent-libs/config"
"github.com/elastic/go-ucfg"
"github.com/stretchr/testify/assert"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -69,6 +70,36 @@ func TestLoadWithEmptyVerificationMode(t *testing.T) {
assert.Equal(t, cfg.VerificationMode, VerifyFull)
}

func TestRepackConfig(t *testing.T) {
cfg, err := load(`
enabled: true
verification_mode: certificate
supported_protocols: [TLSv1.1, TLSv1.2]
cipher_suites:
- RSA-AES-256-CBC-SHA
certificate_authorities:
- /path/to/ca.crt
certificate: /path/to/cert.crt
key: /path/to/key.crt
curve_types:
- P-521
renegotiation: freely
ca_sha256:
- example
ca_trusted_fingerprint: fingerprint
`)

assert.NoError(t, err)
assert.Equal(t, cfg.VerificationMode, VerifyCertificate)

tmp, err := ucfg.NewFrom(cfg)
assert.NoError(t, err)

err = tmp.Unpack(cfg)
assert.NoError(t, err)
assert.Equal(t, cfg.VerificationMode, VerifyCertificate)
}

func TestTLSClientAuthUnpack(t *testing.T) {
tests := []struct {
val string
Expand Down
23 changes: 18 additions & 5 deletions transport/tlscommon/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,25 @@ func (v TLSVersion) Details() *TLSVersionDetails {
}

// Unpack transforms the string into a constant.
func (v *TLSVersion) Unpack(s string) error {
version, found := tlsProtocolVersions[s]
if !found {
return fmt.Errorf("invalid tls version '%v'", s)
func (v *TLSVersion) Unpack(i interface{}) error {
switch o := i.(type) {
case string:
version, found := tlsProtocolVersions[o]
if !found {
return fmt.Errorf("invalid tls version '%v'", o)
}
*v = version
case uint64:
*v = TLSVersion(o)
default:
return fmt.Errorf("tls version is an unknown type: %T", o)
}
return nil
}

*v = version
func (v *TLSVersion) Validate() error {
if *v < TLSVersionMin || *v > TLSVersionMax {
return fmt.Errorf("unsupported tls version: %v", v)
}
return nil
}

0 comments on commit b82cfc5

Please sign in to comment.