Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClientHandshake to return AuthInfo #956

Merged
merged 20 commits into from
Jan 9, 2017
Merged
4 changes: 1 addition & 3 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
return conn, nil, nil
return conn, TLSInfo{conn.ConnectionState()}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
Expand Down
161 changes: 161 additions & 0 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
package credentials

import (
"crypto/tls"
"net"
"testing"

"golang.org/x/net/context"
)

func TestTLSOverrideServerName(t *testing.T) {
Expand All @@ -58,4 +62,161 @@ func TestTLSClone(t *testing.T) {
if c.Info().ServerName != expectedServerName {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}

}

const tlsDir = "../test/testdata/"

type serverHandshake func(net.Conn) (AuthInfo, error)

func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
done := make(chan AuthInfo, 1)
lis := launchServer(t, tlsServerHandshake, done)
defer lis.Close()
lisAddr := lis.Addr().String()
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
// wait until server sends serverAuthInfo or fails.
serverAuthInfo, ok := <-done
if !ok {
t.Fatalf("Error at server-side")
}
if !compare(clientAuthInfo, serverAuthInfo) {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
}
}

func TestServerHandshakeReturnsAuthInfo(t *testing.T) {
done := make(chan AuthInfo, 1)
lis := launchServer(t, gRPCServerHandshake, done)
defer lis.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to the place right after listen().
If something fails before this line, lis will not be closed().

clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
// wait until server sends serverAuthInfo or fails.
serverAuthInfo, ok := <-done
if !ok {
t.Fatalf("Error at server-side")
}
if !compare(clientAuthInfo, serverAuthInfo) {
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
}
}

func TestServerAndClientHandshake(t *testing.T) {
done := make(chan AuthInfo, 1)
lis := launchServer(t, gRPCServerHandshake, done)
defer lis.Close()
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
// wait until server sends serverAuthInfo or fails.
serverAuthInfo, ok := <-done
if !ok {
t.Fatalf("Error at server-side")
}
if !compare(clientAuthInfo, serverAuthInfo) {
t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the above client side logic can be wrapped into a function too:

func clientHandle(hs func(context.Context, string, net.Conn) (net.Conn, AuthInfo, error))

}

func compare(a1, a2 AuthInfo) bool {
if a1.AuthType() != a2.AuthType() {
return false
}
switch a1.AuthType() {
case "tls":
state1 := a1.(TLSInfo).State
state2 := a2.(TLSInfo).State
if state1.Version == state2.Version &&
state1.HandshakeComplete == state2.HandshakeComplete &&
state1.CipherSuite == state2.CipherSuite &&
state1.NegotiatedProtocol == state2.NegotiatedProtocol {
return true
}
return false
default:
return false
}
}

func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
go serverHandle(t, hs, done, lis)
return lis
}

// Is run in a seperate goroutine.
func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) {
serverRawConn, err := lis.Accept()
if err != nil {
t.Errorf("Server failed to accept connection: %v", err)
close(done)
return
}
serverAuthInfo, err := hs(serverRawConn)
if err != nil {
t.Errorf("Server failed while handshake. Error: %v", err)
close(done)
return
}
done <- serverAuthInfo
}

func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo {
conn, err := net.Dial("tcp", lisAddr)
if err != nil {
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
}
defer conn.Close()
clientAuthInfo, err := hs(conn, lisAddr)
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
return clientAuthInfo
}

// Server handshake implementation in gRPC.
func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) {
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
return nil, err
}
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
if err != nil {
return nil, err
}
return serverAuthInfo, nil
}

// Client handshake implementation in gRPC.
func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) {
clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true})
_, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn)
if err != nil {
return nil, err
}
return authInfo, nil
}

func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
return nil, err
}
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
serverConn := tls.Server(conn, serverTLSConfig)
err = serverConn.Handshake()
if err != nil {
return nil, err
}
return TLSInfo{State: serverConn.ConnectionState()}, nil
}

func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
clientConn := tls.Client(conn, clientTLSConfig)
err := clientConn.Handshake()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if err := clientConn.Handshake(); err != nil {
  ...
}

if err != nil {
return nil, err
}
return TLSInfo{State: clientConn.ConnectionState()}, nil
}