From a7632e5db0bc7b2be582df07ae1e04c01c6fa67b Mon Sep 17 00:00:00 2001 From: Piotr Piotrowski Date: Fri, 22 Sep 2023 13:17:41 +0200 Subject: [PATCH] [ADDED] Setting TLS config with callbacks in Connect Signed-off-by: Piotr Piotrowski --- nats.go | 133 +++++++++++++++++++++++++++++----------------- test/conn_test.go | 98 ++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 49 deletions(-) diff --git a/nats.go b/nats.go index 721624c27..0be428932 100644 --- a/nats.go +++ b/nats.go @@ -90,55 +90,56 @@ const ( // Errors var ( - ErrConnectionClosed = errors.New("nats: connection closed") - ErrConnectionDraining = errors.New("nats: connection draining") - ErrDrainTimeout = errors.New("nats: draining connection timed out") - ErrConnectionReconnecting = errors.New("nats: connection reconnecting") - ErrSecureConnRequired = errors.New("nats: secure connection required") - ErrSecureConnWanted = errors.New("nats: secure connection not available") - ErrBadSubscription = errors.New("nats: invalid subscription") - ErrTypeSubscription = errors.New("nats: invalid subscription type") - ErrBadSubject = errors.New("nats: invalid subject") - ErrBadQueueName = errors.New("nats: invalid queue name") - ErrSlowConsumer = errors.New("nats: slow consumer, messages dropped") - ErrTimeout = errors.New("nats: timeout") - ErrBadTimeout = errors.New("nats: timeout invalid") - ErrAuthorization = errors.New("nats: authorization violation") - ErrAuthExpired = errors.New("nats: authentication expired") - ErrAuthRevoked = errors.New("nats: authentication revoked") - ErrAccountAuthExpired = errors.New("nats: account authentication expired") - ErrNoServers = errors.New("nats: no servers available for connection") - ErrJsonParse = errors.New("nats: connect message, json parse error") - ErrChanArg = errors.New("nats: argument needs to be a channel type") - ErrMaxPayload = errors.New("nats: maximum payload exceeded") - ErrMaxMessages = errors.New("nats: maximum messages delivered") - ErrSyncSubRequired = errors.New("nats: illegal call on an async subscription") - ErrMultipleTLSConfigs = errors.New("nats: multiple tls.Configs not allowed") - ErrNoInfoReceived = errors.New("nats: protocol exception, INFO not received") - ErrReconnectBufExceeded = errors.New("nats: outbound buffer limit exceeded") - ErrInvalidConnection = errors.New("nats: invalid connection") - ErrInvalidMsg = errors.New("nats: invalid message or message nil") - ErrInvalidArg = errors.New("nats: invalid argument") - ErrInvalidContext = errors.New("nats: invalid context") - ErrNoDeadlineContext = errors.New("nats: context requires a deadline") - ErrNoEchoNotSupported = errors.New("nats: no echo option not supported by this server") - ErrClientIDNotSupported = errors.New("nats: client ID not supported by this server") - ErrUserButNoSigCB = errors.New("nats: user callback defined without a signature handler") - ErrNkeyButNoSigCB = errors.New("nats: nkey defined without a signature handler") - ErrNoUserCB = errors.New("nats: user callback not defined") - ErrNkeyAndUser = errors.New("nats: user callback and nkey defined") - ErrNkeysNotSupported = errors.New("nats: nkeys not supported by the server") - ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION) - ErrTokenAlreadySet = errors.New("nats: token and token handler both set") - ErrMsgNotBound = errors.New("nats: message is not bound to subscription/connection") - ErrMsgNoReply = errors.New("nats: message does not have a reply") - ErrClientIPNotSupported = errors.New("nats: client IP not supported by this server") - ErrDisconnected = errors.New("nats: server is disconnected") - ErrHeadersNotSupported = errors.New("nats: headers not supported by this server") - ErrBadHeaderMsg = errors.New("nats: message could not decode headers") - ErrNoResponders = errors.New("nats: no responders available for request") - ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded") - ErrConnectionNotTLS = errors.New("nats: connection is not tls") + ErrConnectionClosed = errors.New("nats: connection closed") + ErrConnectionDraining = errors.New("nats: connection draining") + ErrDrainTimeout = errors.New("nats: draining connection timed out") + ErrConnectionReconnecting = errors.New("nats: connection reconnecting") + ErrSecureConnRequired = errors.New("nats: secure connection required") + ErrSecureConnWanted = errors.New("nats: secure connection not available") + ErrBadSubscription = errors.New("nats: invalid subscription") + ErrTypeSubscription = errors.New("nats: invalid subscription type") + ErrBadSubject = errors.New("nats: invalid subject") + ErrBadQueueName = errors.New("nats: invalid queue name") + ErrSlowConsumer = errors.New("nats: slow consumer, messages dropped") + ErrTimeout = errors.New("nats: timeout") + ErrBadTimeout = errors.New("nats: timeout invalid") + ErrAuthorization = errors.New("nats: authorization violation") + ErrAuthExpired = errors.New("nats: authentication expired") + ErrAuthRevoked = errors.New("nats: authentication revoked") + ErrAccountAuthExpired = errors.New("nats: account authentication expired") + ErrNoServers = errors.New("nats: no servers available for connection") + ErrJsonParse = errors.New("nats: connect message, json parse error") + ErrChanArg = errors.New("nats: argument needs to be a channel type") + ErrMaxPayload = errors.New("nats: maximum payload exceeded") + ErrMaxMessages = errors.New("nats: maximum messages delivered") + ErrSyncSubRequired = errors.New("nats: illegal call on an async subscription") + ErrMultipleTLSConfigs = errors.New("nats: multiple tls.Configs not allowed") + ErrClientCertOrRootCAsRequired = errors.New("nats: at least one of certCB or rootCAsCB must be set") + ErrNoInfoReceived = errors.New("nats: protocol exception, INFO not received") + ErrReconnectBufExceeded = errors.New("nats: outbound buffer limit exceeded") + ErrInvalidConnection = errors.New("nats: invalid connection") + ErrInvalidMsg = errors.New("nats: invalid message or message nil") + ErrInvalidArg = errors.New("nats: invalid argument") + ErrInvalidContext = errors.New("nats: invalid context") + ErrNoDeadlineContext = errors.New("nats: context requires a deadline") + ErrNoEchoNotSupported = errors.New("nats: no echo option not supported by this server") + ErrClientIDNotSupported = errors.New("nats: client ID not supported by this server") + ErrUserButNoSigCB = errors.New("nats: user callback defined without a signature handler") + ErrNkeyButNoSigCB = errors.New("nats: nkey defined without a signature handler") + ErrNoUserCB = errors.New("nats: user callback not defined") + ErrNkeyAndUser = errors.New("nats: user callback and nkey defined") + ErrNkeysNotSupported = errors.New("nats: nkeys not supported by the server") + ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION) + ErrTokenAlreadySet = errors.New("nats: token and token handler both set") + ErrMsgNotBound = errors.New("nats: message is not bound to subscription/connection") + ErrMsgNoReply = errors.New("nats: message does not have a reply") + ErrClientIPNotSupported = errors.New("nats: client IP not supported by this server") + ErrDisconnected = errors.New("nats: server is disconnected") + ErrHeadersNotSupported = errors.New("nats: headers not supported by this server") + ErrBadHeaderMsg = errors.New("nats: message could not decode headers") + ErrNoResponders = errors.New("nats: no responders available for request") + ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded") + ErrConnectionNotTLS = errors.New("nats: connection is not tls") ) // GetDefaultOptions returns default configuration options for the client. @@ -864,6 +865,40 @@ func Secure(tls ...*tls.Config) Option { } } +// ClientTLSConfig is an Option to set the TLS configuration for secure +// connections. It can be used to e.g. set TLS config with cert and root CAs +// from memory. For simple use case of loading cert and CAs from file, +// ClientCert and RootCAs options are more convenient. +// If Secure is not already set this will set it as well. +func ClientTLSConfig(certCB TLSCertHandler, rootCAsCB RootCAsHandler) Option { + return func(o *Options) error { + o.Secure = true + + if certCB == nil && rootCAsCB == nil { + return ErrClientCertOrRootCAsRequired + } + + // Smoke test the callbacks to fail early + // if they are not valid. + if certCB != nil { + if _, err := certCB(); err != nil { + return err + } + } + if rootCAsCB != nil { + if _, err := rootCAsCB(); err != nil { + return err + } + } + if o.TLSConfig == nil { + o.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + o.TLSCertCB = certCB + o.RootCAsCB = rootCAsCB + return nil + } +} + // RootCAs is a helper option to provide the RootCAs pool from a list of filenames. // If Secure is not already set this will set it as well. func RootCAs(file ...string) Option { diff --git a/test/conn_test.go b/test/conn_test.go index e621c902b..36d602c96 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -18,6 +18,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "errors" "fmt" "net" "os" @@ -234,6 +235,103 @@ func TestServerSecureConnections(t *testing.T) { } } +func TestClientTLSConfig(t *testing.T) { + s, opts := RunServerWithConfig("./configs/tlsverify.conf") + defer s.Shutdown() + + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + secureURL := fmt.Sprintf("nats://%s", endpoint) + + // Make sure this fails + nc, err := nats.Connect(secureURL, nats.Secure()) + if err == nil { + nc.Close() + t.Fatal("Should have failed (TLS) connection without client certificate") + } + cert, err := os.ReadFile("./configs/certs/client-cert.pem") + if err != nil { + t.Fatal("Failed to read client certificate") + } + key, err := os.ReadFile("./configs/certs/client-key.pem") + if err != nil { + t.Fatal("Failed to read client key") + } + rootCAs, err := os.ReadFile("./configs/certs/ca.pem") + if err != nil { + t.Fatal("Failed to read root CAs") + } + + certCB := func() (tls.Certificate, error) { + cert, err := tls.X509KeyPair(cert, key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("nats: error loading client certificate: %w", err) + } + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return tls.Certificate{}, fmt.Errorf("nats: error parsing client certificate: %w", err) + } + return cert, nil + } + + caCB := func() (*x509.CertPool, error) { + pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(rootCAs) + if !ok { + return nil, fmt.Errorf("nats: failed to parse root certificate from") + } + return pool, nil + } + + // Check parameters validity + _, err = nats.Connect(secureURL, nats.ClientTLSConfig(nil, nil)) + if !errors.Is(err, nats.ErrClientCertOrRootCAsRequired) { + t.Fatalf("Expected error %q, got %q", nats.ErrClientCertOrRootCAsRequired, err) + } + + certErr := &tls.CertificateVerificationError{} + // Should fail because of missing CA + _, err = nats.Connect(secureURL, + nats.ClientCert("./configs/certs/client-cert.pem", "./configs/certs/client-key.pem")) + if ok := errors.As(err, &certErr); !ok { + t.Fatalf("Expected error %q, got %q", nats.ErrClientCertOrRootCAsRequired, err) + } + + // Should fail because of missing certificate + _, err = nats.Connect(secureURL, + nats.ClientTLSConfig(nil, caCB)) + if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") { + t.Fatalf("Expected missing certificate error; got: %s", err) + } + + nc, err = nats.Connect(secureURL, + nats.ClientTLSConfig(certCB, caCB)) + if err != nil { + t.Fatalf("Failed to create (TLS) connection: %v", err) + } + defer nc.Close() + + omsg := []byte("Hello!") + checkRecv := make(chan bool) + + received := 0 + nc.Subscribe("foo", func(m *nats.Msg) { + received++ + if !bytes.Equal(m.Data, omsg) { + t.Fatal("Message received does not match") + } + checkRecv <- true + }) + err = nc.Publish("foo", omsg) + if err != nil { + t.Fatalf("Failed to publish on secure (TLS) connection: %v", err) + } + nc.Flush() + + if err := Wait(checkRecv); err != nil { + t.Fatal("Failed to receive message") + } +} + func TestClientCertificate(t *testing.T) { s, opts := RunServerWithConfig("./configs/tlsverify.conf") defer s.Shutdown()