From 4ed348913c8a37cfab46b8684ff1ed794b13cb86 Mon Sep 17 00:00:00 2001 From: MakMukhi Date: Mon, 9 Jan 2017 13:29:20 -0800 Subject: [PATCH] ClientHandshake to return AuthInfo (#956) * Initial commit * Initial commit 2 * minor update * goimport update * resolved race condition * added test for TLSInfo on server side * Post review updates * port review changes debug debug * refactoring and added third function * post review changes * post review changes * post review updates * post review commit * post review commit * post review update * post review update * post review update * post review update * post review commit * post review update --- credentials/credentials.go | 4 +- credentials/credentials_test.go | 160 ++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 3 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 5555ef024f67..4d45c3e3c7f2 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net case <-ctx.Done(): return nil, nil, ctx.Err() } - // TODO(zhaoq): Omit the auth info for client now. It is more for - // information than anything else. - return conn, nil, nil + return conn, TLSInfo{conn.ConnectionState()}, nil } func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index caf35b2feca8..a5db3867c8f3 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -34,7 +34,11 @@ package credentials import ( + "crypto/tls" + "net" "testing" + + "golang.org/x/net/context" ) func TestTLSOverrideServerName(t *testing.T) { @@ -58,4 +62,160 @@ func TestTLSClone(t *testing.T) { if c.Info().ServerName != expectedServerName { t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) } + +} + +const tlsDir = "../test/testdata/" + +type serverHandshake func(net.Conn) (AuthInfo, error) + +func TestClientHandshakeReturnsAuthInfo(t *testing.T) { + done := make(chan AuthInfo, 1) + lis := launchServer(t, tlsServerHandshake, done) + defer lis.Close() + lisAddr := lis.Addr().String() + clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) + } +} + +func TestServerHandshakeReturnsAuthInfo(t *testing.T) { + done := make(chan AuthInfo, 1) + lis := launchServer(t, gRPCServerHandshake, done) + defer lis.Close() + clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) + } +} + +func TestServerAndClientHandshake(t *testing.T) { + done := make(chan AuthInfo, 1) + lis := launchServer(t, gRPCServerHandshake, done) + defer lis.Close() + clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) + } +} + +func compare(a1, a2 AuthInfo) bool { + if a1.AuthType() != a2.AuthType() { + return false + } + switch a1.AuthType() { + case "tls": + state1 := a1.(TLSInfo).State + state2 := a2.(TLSInfo).State + if state1.Version == state2.Version && + state1.HandshakeComplete == state2.HandshakeComplete && + state1.CipherSuite == state2.CipherSuite && + state1.NegotiatedProtocol == state2.NegotiatedProtocol { + return true + } + return false + default: + return false + } +} + +func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + go serverHandle(t, hs, done, lis) + return lis +} + +// Is run in a seperate goroutine. +func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) { + serverRawConn, err := lis.Accept() + if err != nil { + t.Errorf("Server failed to accept connection: %v", err) + close(done) + return + } + serverAuthInfo, err := hs(serverRawConn) + if err != nil { + t.Errorf("Server failed while handshake. Error: %v", err) + close(done) + return + } + done <- serverAuthInfo +} + +func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo { + conn, err := net.Dial("tcp", lisAddr) + if err != nil { + t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) + } + defer conn.Close() + clientAuthInfo, err := hs(conn, lisAddr) + if err != nil { + t.Fatalf("Error on client while handshake. Error: %v", err) + } + return clientAuthInfo +} + +// Server handshake implementation in gRPC. +func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) { + serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") + if err != nil { + return nil, err + } + _, serverAuthInfo, err := serverTLS.ServerHandshake(conn) + if err != nil { + return nil, err + } + return serverAuthInfo, nil +} + +// Client handshake implementation in gRPC. +func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) { + clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true}) + _, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +func tlsServerHandshake(conn net.Conn) (AuthInfo, error) { + cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key") + if err != nil { + return nil, err + } + serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + serverConn := tls.Server(conn, serverTLSConfig) + err = serverConn.Handshake() + if err != nil { + return nil, err + } + return TLSInfo{State: serverConn.ConnectionState()}, nil +} + +func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) { + clientTLSConfig := &tls.Config{InsecureSkipVerify: true} + clientConn := tls.Client(conn, clientTLSConfig) + if err := clientConn.Handshake(); err != nil { + return nil, err + } + return TLSInfo{State: clientConn.ConnectionState()}, nil }