From 6ea4f03923ace04838180ac0d6cb5aec661e96ca Mon Sep 17 00:00:00 2001 From: Michel Laterman <82832767+michel-laterman@users.noreply.github.com> Date: Tue, 16 Apr 2024 10:49:56 -0700 Subject: [PATCH] Add both uint64 and int64 to tlscommon types unpack methods (#198) Add both uint64 and int64 to tlscommon types unpack methods --- transport/tlscommon/server_config_test.go | 77 +++++++ transport/tlscommon/tls_test.go | 48 ++++ transport/tlscommon/types.go | 8 + transport/tlscommon/types_test.go | 267 ++++++++++++++++++++++ transport/tlscommon/versions.go | 2 + transport/tlscommon/versions_test.go | 45 ++++ 6 files changed, 447 insertions(+) diff --git a/transport/tlscommon/server_config_test.go b/transport/tlscommon/server_config_test.go index 2b7d9756..9f9bc8b9 100644 --- a/transport/tlscommon/server_config_test.go +++ b/transport/tlscommon/server_config_test.go @@ -173,3 +173,80 @@ func Test_ServerConfig_Repack(t *testing.T) { }) } } + +func Test_ServerConfig_RepackJSON(t *testing.T) { + tests := []struct { + name string + json string + auth *TLSClientAuth + }{{ + name: "with client auth", + json: `{ + "enabled": true, + "verification_mode": "certificate", + "supported_protocols": ["TLSv1.1", "TLSv1.2"], + "cipher_suites": ["RSA-AES-256-CBC-SHA"], + "certificate_authorities": ["/path/to/ca.crt"], + "certificate": "/path/to/cert.crt", + "key": "/path/to/key.crt", + "curve_types": "P-521", + "renegotiation": "freely", + "ca_sha256": ["example"], + "ca_trusted_fingerprint": "fingerprint", + "client_authentication": "optional" + }`, + auth: &optional, + }, { + name: "nil client auth", + json: `{ + "enabled": true, + "verification_mode": "certificate", + "supported_protocols": ["TLSv1.1", "TLSv1.2"], + "cipher_suites": ["RSA-AES-256-CBC-SHA"], + "certificate_authorities": ["/path/to/ca.crt"], + "certificate": "/path/to/cert.crt", + "key": "/path/to/key.crt", + "curve_types": "P-521", + "renegotiation": "freely", + "ca_sha256": ["example"], + "ca_trusted_fingerprint": "fingerprint" + }`, + auth: &required, + }, { + name: "nil client auth, no cas", + json: `{ + "enabled": true, + "verification_mode": "certificate", + "supported_protocols": ["TLSv1.1", "TLSv1.2"], + "cipher_suites": ["RSA-AES-256-CBC-SHA"], + "certificate": "/path/to/cert.crt", + "key": "/path/to/key.crt", + "curve_types": "P-521", + "renegotiation": "freely", + "ca_sha256": ["example"] + }`, + auth: nil, + }} + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cfg := mustLoadServerConfigJSON(t, tc.json) + if tc.auth != nil { + require.Equal(t, *tc.auth, *cfg.ClientAuth) + } else { + require.Nil(t, cfg.ClientAuth) + } + + tmp, err := ucfg.NewFrom(cfg) + require.NoError(t, err) + + err = tmp.Unpack(&cfg) + require.NoError(t, err) + if tc.auth != nil { + require.Equal(t, *tc.auth, *cfg.ClientAuth) + } else { + require.Nil(t, cfg.ClientAuth) + } + }) + } +} diff --git a/transport/tlscommon/tls_test.go b/transport/tlscommon/tls_test.go index 27df2b89..0de8872b 100644 --- a/transport/tlscommon/tls_test.go +++ b/transport/tlscommon/tls_test.go @@ -29,6 +29,9 @@ import ( "github.com/stretchr/testify/require" "github.com/elastic/elastic-agent-libs/config" + + ucfg "github.com/elastic/go-ucfg" + "github.com/elastic/go-ucfg/json" ) const ( @@ -76,6 +79,50 @@ func mustLoad(t *testing.T, yamlStr string) *Config { return cfg } +// copied from config.fromConfig +func cfgConvert(in *ucfg.Config) *config.C { + return (*config.C)(in) +} + +func loadJSON(jsonStr string) (*Config, error) { + var cfg Config + uc, err := json.NewConfig([]byte(jsonStr), ucfg.PathSep("."), ucfg.VarExp) + if err != nil { + return nil, err + } + + c := cfgConvert(uc) + + if err = c.Unpack(&cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func loadServerConfigJSON(jsonStr string) (*ServerConfig, error) { + var cfg ServerConfig + uc, err := json.NewConfig([]byte(jsonStr), ucfg.PathSep("."), ucfg.VarExp) + if err != nil { + return nil, err + } + + c := cfgConvert(uc) + + if err = c.Unpack(&cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func mustLoadServerConfigJSON(t *testing.T, jsonStr string) *ServerConfig { + t.Helper() + cfg, err := loadServerConfigJSON(jsonStr) + if err != nil { + t.Fatal(err) + } + return cfg +} + func writeTestFile(t *testing.T, content string) string { t.Helper() f, err := os.CreateTemp(t.TempDir(), "") @@ -647,6 +694,7 @@ mrPVWmOCMtwHJrO7kF1ENDgHPkhoZFcpFhu3lzOY7mhpW5mPZPVs87ZmI75G7zMV AcV8KJqa/7XTTpvIzXePw9FtSSux5SkU6iKAKqwUt82D1E73bbppSg== -----END CERTIFICATE----- ` + //nolint:gosec // testing key key := ` -----BEGIN RSA PRIVATE KEY----- Proc-Type: 4,ENCRYPTED diff --git a/transport/tlscommon/types.go b/transport/tlscommon/types.go index 513db1ed..531023f7 100644 --- a/transport/tlscommon/types.go +++ b/transport/tlscommon/types.go @@ -177,6 +177,8 @@ func (m *TLSVerificationMode) Unpack(in interface{}) error { return fmt.Errorf("unknown verification mode '%v'", o) } *m = mode + case int64: + *m = TLSVerificationMode(o) case uint64: *m = TLSVerificationMode(o) default: @@ -244,6 +246,8 @@ func (cs *CipherSuite) Unpack(i interface{}) error { } *cs = suite + case int64: + *cs = CipherSuite(o) case uint64: *cs = CipherSuite(o) default: @@ -270,6 +274,8 @@ func (ct *tlsCurveType) Unpack(i interface{}) error { } *ct = t + case int64: + *ct = tlsCurveType(o) case uint64: *ct = tlsCurveType(o) default: @@ -296,6 +302,8 @@ func (r *TLSRenegotiationSupport) Unpack(i interface{}) error { } *r = t + case int64: + *r = TLSRenegotiationSupport(o) case uint64: *r = TLSRenegotiationSupport(o) default: diff --git a/transport/tlscommon/types_test.go b/transport/tlscommon/types_test.go index fb88d694..e0c12cc7 100644 --- a/transport/tlscommon/types_test.go +++ b/transport/tlscommon/types_test.go @@ -18,6 +18,7 @@ package tlscommon import ( + "crypto/tls" "fmt" "testing" @@ -100,6 +101,32 @@ func TestRepackConfig(t *testing.T) { assert.Equal(t, cfg.VerificationMode, VerifyCertificate) } +func TestRepackConfigFromJSON(t *testing.T) { + cfg, err := loadJSON(`{ + "enabled": true, + "verification_mode": "certificate", + "supported_protocols": ["TLSv1.1", "TLSv1.2"], + "cipher_suites": ["RSA-AES-256-CBC-SHA"], + "certificate_authorities": ["/path/to/ca.crt"], + "certificate": "/path/to/cert.crt", + "key": "/path/to/key.crt", + "curve_types": "P-521", + "renegotiation": "freely", + "ca_sha256": ["example"], + "ca_trusted_fingerprint": "fingerprint" + }`) + + assert.NoError(t, err) + assert.Equal(t, cfg.VerificationMode, VerifyCertificate) + + tmp, err := ucfg.NewFrom(cfg) + assert.NoError(t, err) + + err = tmp.Unpack(cfg) + assert.NoError(t, err) + assert.Equal(t, cfg.VerificationMode, VerifyCertificate) +} + func TestTLSClientAuthUnpack(t *testing.T) { tests := []struct { val string @@ -262,3 +289,243 @@ func mustLoadServerConfig(t *testing.T, yamlStr string) *ServerConfig { } return cfg } + +func Test_TLSVerificaionMode_Unpack(t *testing.T) { + tests := []struct { + name string + hasErr bool + in interface{} + exp TLSVerificationMode + }{{ + name: "nil", + hasErr: false, + in: nil, + exp: VerifyFull, + }, { + name: "empty string", + hasErr: false, + in: "", + exp: VerifyFull, + }, { + name: "unknown string", + hasErr: true, + in: "unknown", + }, { + name: "string", + hasErr: false, + in: "strict", + exp: VerifyStrict, + }, { + name: "int64", + hasErr: false, + in: int64(1), + exp: VerifyNone, + }, { + name: "uint64", + hasErr: false, + in: uint64(1), + exp: VerifyNone, + }, { + name: "unknown type", + hasErr: true, + in: uint8(1), + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + v := new(TLSVerificationMode) + err := v.Unpack(tc.in) + if tc.hasErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.exp, *v) + } + }) + } +} + +func Test_TLSClientAuth_Unpack(t *testing.T) { + tests := []struct { + name string + hasErr bool + in interface{} + exp TLSClientAuth + }{{ + name: "nil", + hasErr: false, + in: nil, + exp: TLSClientAuthNone, + }, { + name: "empty string", + hasErr: false, + in: "", + exp: TLSClientAuthNone, + }, { + name: "unknown string", + hasErr: true, + in: "unknown", + }, { + name: "string", + hasErr: false, + in: "optional", + exp: TLSClientAuthOptional, + }, { + name: "int64", + hasErr: false, + in: int64(3), + exp: TLSClientAuthOptional, + }, { + name: "uint64", + hasErr: false, + in: uint64(3), + exp: TLSClientAuthOptional, + }, { + name: "unknown type", + hasErr: true, + in: uint8(1), + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + v := new(TLSClientAuth) + err := v.Unpack(tc.in) + if tc.hasErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.exp, *v) + } + }) + } +} + +func Test_CipherSuite_Unpack(t *testing.T) { + tests := []struct { + name string + hasErr bool + in interface{} + exp CipherSuite + }{{ + name: "unknown string", + hasErr: true, + in: "unknown", + }, { + name: "string", + hasErr: false, + in: "RSA-AES-128-CBC-SHA", + exp: CipherSuite(tls.TLS_RSA_WITH_AES_128_CBC_SHA), + }, { + name: "int64", + hasErr: false, + in: int64(47), + exp: CipherSuite(tls.TLS_RSA_WITH_AES_128_CBC_SHA), + }, { + name: "uint64", + hasErr: false, + in: uint64(47), + exp: CipherSuite(tls.TLS_RSA_WITH_AES_128_CBC_SHA), + }, { + name: "unknown type", + hasErr: true, + in: uint8(1), + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + v := new(CipherSuite) + err := v.Unpack(tc.in) + if tc.hasErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.exp, *v) + } + }) + } +} + +func Test_tlsCurveType_Unpack(t *testing.T) { + tests := []struct { + name string + hasErr bool + in interface{} + exp tlsCurveType + }{{ + name: "unknown string", + hasErr: true, + in: "unknown", + }, { + name: "string", + hasErr: false, + in: "P-256", + exp: tlsCurveType(tls.CurveP256), + }, { + name: "int64", + hasErr: false, + in: int64(23), + exp: tlsCurveType(tls.CurveP256), + }, { + name: "uint64", + hasErr: false, + in: uint64(23), + exp: tlsCurveType(tls.CurveP256), + }, { + name: "unknown type", + hasErr: true, + in: uint8(1), + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + v := new(tlsCurveType) + err := v.Unpack(tc.in) + if tc.hasErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.exp, *v) + } + }) + } +} + +func Test_TLSRenegotiationSupport_Unpack(t *testing.T) { + tests := []struct { + name string + hasErr bool + in interface{} + exp TLSRenegotiationSupport + }{{ + name: "unknown string", + hasErr: true, + in: "unknown", + }, { + name: "string", + hasErr: false, + in: "never", + exp: TLSRenegotiationSupport(tls.RenegotiateNever), + }, { + name: "int64", + hasErr: false, + in: int64(0), + exp: TLSRenegotiationSupport(tls.RenegotiateNever), + }, { + name: "uint64", + hasErr: false, + in: uint64(0), + exp: TLSRenegotiationSupport(tls.RenegotiateNever), + }, { + name: "unknown type", + hasErr: true, + in: uint8(1), + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + v := new(TLSRenegotiationSupport) + err := v.Unpack(tc.in) + if tc.hasErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.exp, *v) + } + }) + } +} diff --git a/transport/tlscommon/versions.go b/transport/tlscommon/versions.go index 3e2ad498..7c4d613a 100644 --- a/transport/tlscommon/versions.go +++ b/transport/tlscommon/versions.go @@ -46,6 +46,8 @@ func (v *TLSVersion) Unpack(i interface{}) error { return fmt.Errorf("invalid tls version '%v'", o) } *v = version + case int64: + *v = TLSVersion(o) case uint64: *v = TLSVersion(o) default: diff --git a/transport/tlscommon/versions_test.go b/transport/tlscommon/versions_test.go index 7f2b2e02..c779ca1f 100644 --- a/transport/tlscommon/versions_test.go +++ b/transport/tlscommon/versions_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -70,3 +71,47 @@ func TestTLSVersion(t *testing.T) { }) } } + +func Test_TLSVersion_Unpack(t *testing.T) { + tests := []struct { + name string + hasErr bool + in interface{} + exp TLSVersion + }{{ + name: "unknown string", + hasErr: true, + in: "unknown", + }, { + name: "string", + hasErr: false, + in: "TLSv1.2", + exp: TLSVersion12, + }, { + name: "int64", + hasErr: false, + in: int64(0x303), + exp: TLSVersion12, + }, { + name: "uint64", + hasErr: false, + in: uint64(0x303), + exp: TLSVersion12, + }, { + name: "unknown type", + hasErr: true, + in: uint8(1), + }} + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + v := new(TLSVersion) + err := v.Unpack(tc.in) + if tc.hasErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.exp, *v) + } + }) + } +}