Skip to content

Commit

Permalink
credentials/alts: defer ALTS stream creation until handshake time (#6077
Browse files Browse the repository at this point in the history
)
  • Loading branch information
matthewstevenson88 authored Mar 17, 2023
1 parent 6f44ae8 commit c84a500
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 17 deletions.
50 changes: 33 additions & 17 deletions credentials/alts/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,16 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
// and server options (server options struct does not exist now. When
// caller can provide endpoints, it should be created.

// altsHandshaker is used to complete a ALTS handshaking between client and
// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server.
type altsHandshaker struct {
// RPC stream used to access the ALTS Handshaker service.
stream altsgrpc.HandshakerService_DoHandshakeClient
// the connection to the peer.
conn net.Conn
// a virtual connection to the ALTS handshaker service.
clientConn *grpc.ClientConn
// client handshake options.
clientOpts *ClientHandshakerOptions
// server handshake options.
Expand All @@ -154,39 +156,33 @@ type altsHandshaker struct {
side core.Side
}

// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// NewClientHandshaker creates a core.Handshaker that performs a client-side
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
// service in the metadata server.
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
stream: nil,
conn: c,
clientConn: conn,
clientOpts: opts,
side: core.ClientSide,
}, nil
}

// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// NewServerHandshaker creates a core.Handshaker that performs a server-side
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
// service in the metadata server.
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
if err != nil {
return nil, err
}
return &altsHandshaker{
stream: stream,
stream: nil,
conn: c,
clientConn: conn,
serverOpts: opts,
side: core.ServerSide,
}, nil
}

// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire() {
Expand All @@ -198,6 +194,16 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
}

// TODO(matthewstevenson88): Change unit tests to use public APIs so
// that h.stream can unconditionally be set based on h.clientConn.
if h.stream == nil {
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
}
h.stream = stream
}

// Create target identities from service account list.
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
for _, account := range h.clientOpts.TargetServiceAccounts {
Expand Down Expand Up @@ -229,7 +235,7 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
return conn, authInfo, nil
}

// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire() {
Expand All @@ -241,6 +247,16 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
}

// TODO(matthewstevenson88): Change unit tests to use public APIs so
// that h.stream can unconditionally be set based on h.clientConn.
if h.stream == nil {
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
}
h.stream = stream
}

p := make([]byte, frameLimit)
n, err := h.conn.Read(p)
if err != nil {
Expand Down
64 changes: 64 additions & 0 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
grpc "google.golang.org/grpc"
core "google.golang.org/grpc/credentials/alts/internal"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
Expand Down Expand Up @@ -283,3 +285,65 @@ func (s) TestPeerNotResponding(t *testing.T) {
t.Errorf("ClientHandshake() = %v, want %v", got, want)
}
}

func (s) TestNewClientHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ClientHandshakerOptions{}
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
if err != nil {
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
}
expectedHs := &altsHandshaker{
stream: nil,
conn: conn,
clientConn: clientConn,
clientOpts: opts,
serverOpts: nil,
side: core.ClientSide,
}
cmpOpts := []cmp.Option{
cmp.AllowUnexported(altsHandshaker{}),
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
}
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
}
if hs.(*altsHandshaker).stream != nil {
t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream")
}
if hs.(*altsHandshaker).clientConn != clientConn {
t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn")
}
}

func (s) TestNewServerHandshaker(t *testing.T) {
conn := testutil.NewTestConn(nil, nil)
clientConn := &grpc.ClientConn{}
opts := &ServerHandshakerOptions{}
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
if err != nil {
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
}
expectedHs := &altsHandshaker{
stream: nil,
conn: conn,
clientConn: clientConn,
clientOpts: nil,
serverOpts: opts,
side: core.ServerSide,
}
cmpOpts := []cmp.Option{
cmp.AllowUnexported(altsHandshaker{}),
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
}
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
}
if hs.(*altsHandshaker).stream != nil {
t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream")
}
if hs.(*altsHandshaker).clientConn != clientConn {
t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn")
}
}

0 comments on commit c84a500

Please sign in to comment.