Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ClientHandshake to return AuthInfo #956

Merged
merged 20 commits into from
Jan 9, 2017
Merged
4 changes: 1 addition & 3 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else.
return conn, nil, nil
return conn, TLSInfo{conn.ConnectionState()}, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
Expand Down
105 changes: 105 additions & 0 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
package credentials

import (
"crypto/tls"
"net"
"testing"

"golang.org/x/net/context"
)

func TestTLSOverrideServerName(t *testing.T) {
Expand All @@ -59,3 +63,104 @@ func TestTLSClone(t *testing.T) {
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
}
}

const tlsDir = "../test/testdata/"

func TestTLSClientHandshakeReturnsAuthInfo(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
t.Fatalf("Failed to create server TLS. Error: %v", err)
}
var serverAuthInfo TLSInfo
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use tls.ConnectionState directly?

errChan := make(chan error)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make(chan error, 1) so that this goroutine can finish even though the test fails at line 101.

go func() {
var sErr error
defer func() {
errChan <- sErr
}()
serverRawConn, sErr := lis.Accept()
if sErr != nil {
t.Errorf("Server failed to accept connection: %v", sErr)
return
}
serverConn := tls.Server(serverRawConn, serverTLS.(*tlsCreds).config)
sErr = serverConn.Handshake()
if sErr != nil {
t.Errorf("Error on server while handshake. Error: %v", sErr)
return
}
serverAuthInfo = TLSInfo{serverConn.ConnectionState()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The body can be written in a separate function like:

func serverHandle(hs func(net.Conn) (net.Conn, AuthInfo, error)) {
    ...
}

}()
conn, err := net.Dial("tcp", lis.Addr().String())
if err != nil {
t.Fatalf("Client failed to connect to %v. Error: %v", lis.Addr().String(), err)
}
defer conn.Close()
c := NewTLS(&tls.Config{InsecureSkipVerify: true})
_, authInfo, err := c.ClientHandshake(context.Background(), lis.Addr().String(), conn)
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
// wait until server has populated the serverAuthInfo struct.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.... populated the server AuthInfo or failed.

if err = <-errChan; err != nil {
return
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the above client side logic can be wrapped into a function too:

func clientHandle(hs func(context.Context, string, net.Conn) (net.Conn, AuthInfo, error))

if authInfo.(TLSInfo).State.Version != serverAuthInfo.State.Version {
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lis.Addr().String(), authInfo, serverAuthInfo)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

%s for list.Addr().String()

}
}

func TestTLSServerHandshakeReturnsAuthInfo(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above comments apply to this test case too.

if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
t.Fatalf("Failed to create server TLS. Error: %v", err)
}
var serverAuthInfo AuthInfo
errChan := make(chan error)
go func() {
var sErr error
defer func() {
errChan <- sErr
}()
serverRawConn, sErr := lis.Accept()
if sErr != nil {
t.Errorf("Server failed to accept connection: %v", sErr)
return
}
_, serverAuthInfo, sErr = serverTLS.ServerHandshake(serverRawConn)
if sErr != nil {
t.Errorf("Error on server while handshake. Error: %v", sErr)
return
}
}()
conn, err := net.Dial("tcp", lis.Addr().String())
if err != nil {
t.Fatalf("Client failed to connect to %v. Error: %v", lis.Addr().String(), err)
}
defer conn.Close()
c := NewTLS(&tls.Config{InsecureSkipVerify: true})
clientConn := tls.Client(conn, c.(*tlsCreds).config)
err = clientConn.Handshake()
if err != nil {
t.Fatalf("Error on client while handshake. Error: %v", err)
}
authInfo := TLSInfo{clientConn.ConnectionState()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use tls.ConnectionState directly?

// wait until server has populated the serverAuthInfo struct.
if err = <-errChan; err != nil {
return
}
if authInfo.State.Version != serverAuthInfo.(TLSInfo).State.Version {
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, authInfo)
}

}