-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ClientHandshake to return AuthInfo #956
Changes from 10 commits
1165b1e
d1b12d3
0129b49
c62cf7b
5b3192c
ff332b6
915cb50
6c58b32
9d3e997
c980740
6fdee01
b792ae8
16853da
40952fe
1db9a22
848da09
49c5700
e7832cf
ecc30a5
74f10a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,11 @@ | |
package credentials | ||
|
||
import ( | ||
"crypto/tls" | ||
"net" | ||
"testing" | ||
|
||
"golang.org/x/net/context" | ||
) | ||
|
||
func TestTLSOverrideServerName(t *testing.T) { | ||
|
@@ -58,4 +62,151 @@ func TestTLSClone(t *testing.T) { | |
if c.Info().ServerName != expectedServerName { | ||
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) | ||
} | ||
|
||
} | ||
|
||
const tlsDir = "../test/testdata/" | ||
|
||
type serverHandshake func(net.Conn, *tls.ConnectionState) error | ||
|
||
func TestClientHandshakeReturnsAuthInfo(t *testing.T) { | ||
var serverConnState tls.ConnectionState | ||
done := make(chan error, 1) | ||
lisAddr := launchServer(t, &serverConnState, tlsServerHandshake, done) | ||
clientConnState := clientHandle(t, gRPCClientHandshake, lisAddr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/clientConnState/clientAuthInfo |
||
// wait until server has populated the serverAuthInfo struct or failed. | ||
if err := <-done; err != nil { | ||
return | ||
} | ||
if !isEqualState(clientConnState, serverConnState) { | ||
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientConnState, serverConnState) | ||
} | ||
} | ||
|
||
func TestServerHandshakeReturnsAuthInfo(t *testing.T) { | ||
var serverConnState tls.ConnectionState | ||
done := make(chan error, 1) | ||
lisAddr := launchServer(t, &serverConnState, gRPCServerHandshake, done) | ||
clientConnState := clientHandle(t, tlsClientHandshake, lisAddr) | ||
// wait until server has populated the serverAuthInfo struct or failed. | ||
if err := <-done; err != nil { | ||
return | ||
} | ||
if !isEqualState(clientConnState, serverConnState) { | ||
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverConnState, clientConnState) | ||
} | ||
} | ||
|
||
func TestServerAndClientHandshake(t *testing.T) { | ||
var serverConnState tls.ConnectionState | ||
done := make(chan error, 1) | ||
lisAddr := launchServer(t, &serverConnState, gRPCServerHandshake, done) | ||
clientConnState := clientHandle(t, gRPCClientHandshake, lisAddr) | ||
// wait until server has populated the serverAuthInfo struct or failed. | ||
if err := <-done; err != nil { | ||
return | ||
} | ||
if !isEqualState(clientConnState, serverConnState) { | ||
t.Fatalf("Connection states returened by server: %v and client: %v aren't same", serverConnState, clientConnState) | ||
} | ||
} | ||
|
||
func isEqualState(state1, state2 tls.ConnectionState) bool { | ||
if state1.Version == state2.Version && | ||
state1.HandshakeComplete == state2.HandshakeComplete && | ||
state1.CipherSuite == state2.CipherSuite && | ||
state1.NegotiatedProtocol == state2.NegotiatedProtocol { | ||
return true | ||
} | ||
return false | ||
} | ||
|
||
func launchServer(t *testing.T, serverConnState *tls.ConnectionState, hs serverHandshake, done chan error) string { | ||
lis, err := net.Listen("tcp", "localhost:0") | ||
if err != nil { | ||
t.Fatalf("Failed to listen: %v", err) | ||
} | ||
go serverHandle(t, hs, serverConnState, done, lis) | ||
return lis.Addr().String() | ||
} | ||
|
||
// Is run in a seperate go routine. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. goroutine is one word. |
||
func serverHandle(t *testing.T, hs func(net.Conn, *tls.ConnectionState) error, serverConnState *tls.ConnectionState, done chan error, lis net.Listener) { | ||
defer lis.Close() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this to the place right after listen(). |
||
var err error | ||
defer func() { | ||
done <- err | ||
}() | ||
serverRawConn, err := lis.Accept() | ||
if err != nil { | ||
t.Errorf("Server failed to accept connection: %v", err) | ||
return | ||
} | ||
err = hs(serverRawConn, serverConnState) | ||
if err != nil { | ||
t.Errorf("Error at server-side while handshake. Error: %v", err) | ||
return | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and the above client side logic can be wrapped into a function too:
|
||
} | ||
|
||
func clientHandle(t *testing.T, hs func(net.Conn, string) (tls.ConnectionState, error), lisAddr string) tls.ConnectionState { | ||
conn, err := net.Dial("tcp", lisAddr) | ||
if err != nil { | ||
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) | ||
} | ||
defer conn.Close() | ||
clientConnState, err := hs(conn, lisAddr) | ||
if err != nil { | ||
t.Fatalf("Error on client while handshake. Error: %v", err) | ||
} | ||
return clientConnState | ||
} | ||
|
||
// Server handshake implementation in gRPC. | ||
func gRPCServerHandshake(conn net.Conn, serverConnState *tls.ConnectionState) error { | ||
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") | ||
if err != nil { | ||
return err | ||
} | ||
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn) | ||
if err != nil { | ||
return err | ||
} | ||
*serverConnState = serverAuthInfo.(TLSInfo).State | ||
return nil | ||
} | ||
|
||
// Client handshake implementation in gRPC. | ||
func gRPCClientHandshake(conn net.Conn, lisAddr string) (tls.ConnectionState, error) { | ||
clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true}) | ||
_, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn) | ||
if err != nil { | ||
return tls.ConnectionState{}, err | ||
} | ||
return authInfo.(TLSInfo).State, nil | ||
} | ||
|
||
func tlsServerHandshake(conn net.Conn, serverConnState *tls.ConnectionState) error { | ||
cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key") | ||
if err != nil { | ||
return err | ||
} | ||
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} | ||
serverConn := tls.Server(conn, serverTLSConfig) | ||
err = serverConn.Handshake() | ||
if err != nil { | ||
return err | ||
} | ||
*serverConnState = serverConn.ConnectionState() | ||
return nil | ||
} | ||
|
||
func tlsClientHandshake(conn net.Conn, _ string) (tls.ConnectionState, error) { | ||
clientTLSConfig := &tls.Config{InsecureSkipVerify: true} | ||
clientConn := tls.Client(conn, clientTLSConfig) | ||
err := clientConn.Handshake() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if err != nil { | ||
return tls.ConnectionState{}, err | ||
} | ||
return clientConn.ConnectionState(), nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make this a channel of authInfo. The way launchServer does (&serverConnState) is more like C/C++ way instead of Go idiom. In this way, serverHandshake does not need the second param.