Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClientHandshake to return AuthInfo #956

Merged
merged 20 commits into from
Jan 9, 2017
Merged
4 changes: 1 addition & 3 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
return conn, nil, nil
return conn, TLSInfo{conn.ConnectionState()}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
Expand Down
151 changes: 151 additions & 0 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
package credentials

import (
"crypto/tls"
"net"
"testing"

"golang.org/x/net/context"
)

func TestTLSOverrideServerName(t *testing.T) {
Expand All @@ -58,4 +62,151 @@ func TestTLSClone(t *testing.T) {
if c.Info().ServerName != expectedServerName {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}

}

const tlsDir = "../test/testdata/"

type serverHandshake func(net.Conn, *tls.ConnectionState) error

func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
var serverConnState tls.ConnectionState
done := make(chan error, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this a channel of authInfo. The way launchServer does (&serverConnState) is more like C/C++ way instead of Go idiom. In this way, serverHandshake does not need the second param.

lisAddr := launchServer(t, &serverConnState, tlsServerHandshake, done)
clientConnState := clientHandle(t, gRPCClientHandshake, lisAddr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/clientConnState/clientAuthInfo

// wait until server has populated the serverAuthInfo struct or failed.
if err := <-done; err != nil {
return
}
if !isEqualState(clientConnState, serverConnState) {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientConnState, serverConnState)
}
}

func TestServerHandshakeReturnsAuthInfo(t *testing.T) {
var serverConnState tls.ConnectionState
done := make(chan error, 1)
lisAddr := launchServer(t, &serverConnState, gRPCServerHandshake, done)
clientConnState := clientHandle(t, tlsClientHandshake, lisAddr)
// wait until server has populated the serverAuthInfo struct or failed.
if err := <-done; err != nil {
return
}
if !isEqualState(clientConnState, serverConnState) {
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverConnState, clientConnState)
}
}

func TestServerAndClientHandshake(t *testing.T) {
var serverConnState tls.ConnectionState
done := make(chan error, 1)
lisAddr := launchServer(t, &serverConnState, gRPCServerHandshake, done)
clientConnState := clientHandle(t, gRPCClientHandshake, lisAddr)
// wait until server has populated the serverAuthInfo struct or failed.
if err := <-done; err != nil {
return
}
if !isEqualState(clientConnState, serverConnState) {
t.Fatalf("Connection states returened by server: %v and client: %v aren't same", serverConnState, clientConnState)
}
}

func isEqualState(state1, state2 tls.ConnectionState) bool {
if state1.Version == state2.Version &&
state1.HandshakeComplete == state2.HandshakeComplete &&
state1.CipherSuite == state2.CipherSuite &&
state1.NegotiatedProtocol == state2.NegotiatedProtocol {
return true
}
return false
}

func launchServer(t *testing.T, serverConnState *tls.ConnectionState, hs serverHandshake, done chan error) string {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
go serverHandle(t, hs, serverConnState, done, lis)
return lis.Addr().String()
}

// Is run in a seperate go routine.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

goroutine is one word.

func serverHandle(t *testing.T, hs func(net.Conn, *tls.ConnectionState) error, serverConnState *tls.ConnectionState, done chan error, lis net.Listener) {
defer lis.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to the place right after listen().
If something fails before this line, lis will not be closed().

var err error
defer func() {
done <- err
}()
serverRawConn, err := lis.Accept()
if err != nil {
t.Errorf("Server failed to accept connection: %v", err)
return
}
err = hs(serverRawConn, serverConnState)
if err != nil {
t.Errorf("Error at server-side while handshake. Error: %v", err)
return
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the above client side logic can be wrapped into a function too:

func clientHandle(hs func(context.Context, string, net.Conn) (net.Conn, AuthInfo, error))

}

func clientHandle(t *testing.T, hs func(net.Conn, string) (tls.ConnectionState, error), lisAddr string) tls.ConnectionState {
conn, err := net.Dial("tcp", lisAddr)
if err != nil {
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
}
defer conn.Close()
clientConnState, err := hs(conn, lisAddr)
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
return clientConnState
}

// Server handshake implementation in gRPC.
func gRPCServerHandshake(conn net.Conn, serverConnState *tls.ConnectionState) error {
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
return err
}
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
if err != nil {
return err
}
*serverConnState = serverAuthInfo.(TLSInfo).State
return nil
}

// Client handshake implementation in gRPC.
func gRPCClientHandshake(conn net.Conn, lisAddr string) (tls.ConnectionState, error) {
clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true})
_, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn)
if err != nil {
return tls.ConnectionState{}, err
}
return authInfo.(TLSInfo).State, nil
}

func tlsServerHandshake(conn net.Conn, serverConnState *tls.ConnectionState) error {
cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
return err
}
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
serverConn := tls.Server(conn, serverTLSConfig)
err = serverConn.Handshake()
if err != nil {
return err
}
*serverConnState = serverConn.ConnectionState()
return nil
}

func tlsClientHandshake(conn net.Conn, _ string) (tls.ConnectionState, error) {
clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
clientConn := tls.Client(conn, clientTLSConfig)
err := clientConn.Handshake()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if err := clientConn.Handshake(); err != nil {
  ...
}

if err != nil {
return tls.ConnectionState{}, err
}
return clientConn.ConnectionState(), nil
}