diff --git a/nats.go b/nats.go index 82b79730a..73709db88 100644 --- a/nats.go +++ b/nats.go @@ -311,6 +311,13 @@ type Options struct { // TLSCertCB is used to fetch and return custom tls certificate. TLSCertCB TLSCertHandler + // TLSHandshakeFirst is used to instruct the library perform + // the TLS handshake right after the connect and before receiving + // the INFO protocol from the server. If this option is enabled + // but the server is not configured to perform the TLS handshake + // first, the connection will fail. + TLSHandshakeFirst bool + // RootCAsCB is used to fetch and return a set of root certificate // authorities that clients use when verifying server certificates. RootCAsCB RootCAsHandler @@ -1315,6 +1322,17 @@ func SkipHostLookup() Option { } } +// TLSHandshakeFirst is an Option to perform the TLS handshake first, that is +// before receiving the INFO protocol. This requires the server to also be +// configured with such option, otherwise the connection will fail. +func TLSHandshakeFirst() Option { + return func(o *Options) error { + o.TLSHandshakeFirst = true + o.Secure = true + return nil + } +} + // Handler processing // SetDisconnectHandler will set the disconnect event handler. @@ -1481,6 +1499,12 @@ func (o Options) Connect() (*Conn, error) { } } + // If the TLSHandshakeFirst option is specified, make sure that + // the Secure boolean is true. + if nc.Opts.TLSHandshakeFirst { + nc.Opts.Secure = true + } + if err := nc.setupServerPool(); err != nil { return nil, err } @@ -2235,6 +2259,14 @@ func (nc *Conn) processConnectInit() error { // Set our status to connecting. nc.changeConnStatus(CONNECTING) + // If we need to have a TLS connection and want the TLS handshake to occur + // first, do it now. + if nc.Opts.Secure && nc.Opts.TLSHandshakeFirst { + if err := nc.makeTLSConn(); err != nil { + return err + } + } + // Process the INFO protocol received from the server err := nc.processExpectedInfo() if err != nil { @@ -2351,8 +2383,13 @@ func (nc *Conn) checkForSecure() error { o.Secure = true } - // Need to rewrap with bufio if o.Secure { + // If TLS handshake first is true, we have already done + // the handshake, so we are done here. + if o.TLSHandshakeFirst { + return nil + } + // Need to rewrap with bufio if err := nc.makeTLSConn(); err != nil { return err } diff --git a/test/conn_test.go b/test/conn_test.go index 995b3d1f3..f4bdd877e 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -2863,3 +2863,112 @@ func TestConnStatusChangedEvents(t *testing.T) { time.Sleep(100 * time.Millisecond) }) } + +func TestTLSHandshakeFirst(t *testing.T) { + s, opts := RunServerWithConfig("./configs/tls.conf") + defer s.Shutdown() + + secureURL := fmt.Sprintf("tls://derek:porkchop@localhost:%d", opts.Port) + nc, err := nats.Connect(secureURL, + nats.RootCAs("./configs/certs/ca.pem"), + nats.TLSHandshakeFirst()) + if err == nil || !strings.Contains(err.Error(), "TLS handshake") { + if err == nil { + nc.Close() + } + t.Fatalf("Expected error about not being a TLS handshake, got %v", err) + } + + tc := &server.TLSConfigOpts{ + CertFile: "./configs/certs/server.pem", + KeyFile: "./configs/certs/key.pem", + } + tlsConf, err := server.GenTLSConfig(tc) + if err != nil { + t.Fatalf("Can't build TLCConfig: %v", err) + } + tlsConf.ServerName = "localhost" + + // Start a mockup server that will do the TLS handshake first + // and then send the INFO protcol. + l, e := net.Listen("tcp", ":0") + if e != nil { + t.Fatal("Could not listen on an ephemeral port") + } + tl := l.(*net.TCPListener) + defer tl.Close() + + addr := tl.Addr().(*net.TCPAddr) + + errCh := make(chan error, 1) + doneCh := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + conn, err := l.Accept() + if err != nil { + errCh <- fmt.Errorf("error accepting client connection: %v", err) + return + } + defer conn.Close() + + // Do the TLS handshake now. + conn = tls.Server(conn, tlsConf) + tlsconn := conn.(*tls.Conn) + if err := tlsconn.Handshake(); err != nil { + errCh <- fmt.Errorf("Server error during handshake: %v", err) + return + } + + // Send back the INFO + info := fmt.Sprintf("INFO {\"server_id\":\"foobar\",\"host\":\"localhost\",\"port\":%d,\"auth_required\":false,\"tls_required\":true,\"tls_available\":true,\"tls_verify\":true,\"max_payload\":1048576}\r\n", addr.Port) + tlsconn.Write([]byte(info)) + + // Read connect and ping commands sent from the client + line := make([]byte, 256) + _, err = tlsconn.Read(line) + if err != nil { + errCh <- fmt.Errorf("expected CONNECT and PING from client, got: %s", err) + return + } + tlsconn.Write([]byte("PONG\r\n")) + + // Wait for the signal that client is ok + <-doneCh + // Server is done now. + errCh <- nil + }() + + time.Sleep(100 * time.Millisecond) + + secureURL = fmt.Sprintf("tls://derek:porkchop@localhost:%d", addr.Port) + nc, err = nats.Connect(secureURL, + nats.RootCAs("./configs/certs/ca.pem"), + nats.TLSHandshakeFirst()) + if err != nil { + wg.Wait() + e := <-errCh + t.Fatalf("Unexpected error: %v (server error=%s)", err, e.Error()) + } + + state, err := nc.TLSConnectionState() + if err != nil { + t.Fatalf("Expected connection state: %v", err) + } + if !state.HandshakeComplete { + t.Fatalf("Expected valid connection state") + } + nc.Close() + + close(doneCh) + wg.Wait() + select { + case e := <-errCh: + if e != nil { + t.Fatalf("Error from server: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Server did not exit") + } +}