mirror of https://github.com/grpc/grpc-go.git
credentials/alts: defer ALTS stream creation until handshake time (#6077)
This commit is contained in:
parent
6f44ae89b1
commit
c84a5005d9
|
@ -138,7 +138,7 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
|
|||
// and server options (server options struct does not exist now. When
|
||||
// caller can provide endpoints, it should be created.
|
||||
|
||||
// altsHandshaker is used to complete a ALTS handshaking between client and
|
||||
// altsHandshaker is used to complete an ALTS handshake between client and
|
||||
// server. This handshaker talks to the ALTS handshaker service in the metadata
|
||||
// server.
|
||||
type altsHandshaker struct {
|
||||
|
@ -146,6 +146,8 @@ type altsHandshaker struct {
|
|||
stream altsgrpc.HandshakerService_DoHandshakeClient
|
||||
// the connection to the peer.
|
||||
conn net.Conn
|
||||
// a virtual connection to the ALTS handshaker service.
|
||||
clientConn *grpc.ClientConn
|
||||
// client handshake options.
|
||||
clientOpts *ClientHandshakerOptions
|
||||
// server handshake options.
|
||||
|
@ -154,39 +156,33 @@ type altsHandshaker struct {
|
|||
side core.Side
|
||||
}
|
||||
|
||||
// NewClientHandshaker creates a ALTS handshaker for GCP which contains an RPC
|
||||
// stub created using the passed conn and used to talk to the ALTS Handshaker
|
||||
// NewClientHandshaker creates a core.Handshaker that performs a client-side
|
||||
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
|
||||
// service in the metadata server.
|
||||
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
|
||||
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &altsHandshaker{
|
||||
stream: stream,
|
||||
stream: nil,
|
||||
conn: c,
|
||||
clientConn: conn,
|
||||
clientOpts: opts,
|
||||
side: core.ClientSide,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewServerHandshaker creates a ALTS handshaker for GCP which contains an RPC
|
||||
// stub created using the passed conn and used to talk to the ALTS Handshaker
|
||||
// NewServerHandshaker creates a core.Handshaker that performs a server-side
|
||||
// ALTS handshake by acting as a proxy between the peer and the ALTS handshaker
|
||||
// service in the metadata server.
|
||||
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
|
||||
stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &altsHandshaker{
|
||||
stream: stream,
|
||||
stream: nil,
|
||||
conn: c,
|
||||
clientConn: conn,
|
||||
serverOpts: opts,
|
||||
side: core.ServerSide,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ClientHandshake starts and completes a client ALTS handshaking for GCP. Once
|
||||
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
|
||||
// done, ClientHandshake returns a secure connection.
|
||||
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
|
||||
if !acquire() {
|
||||
|
@ -198,6 +194,16 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
|
|||
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
|
||||
}
|
||||
|
||||
// TODO(matthewstevenson88): Change unit tests to use public APIs so
|
||||
// that h.stream can unconditionally be set based on h.clientConn.
|
||||
if h.stream == nil {
|
||||
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
|
||||
}
|
||||
h.stream = stream
|
||||
}
|
||||
|
||||
// Create target identities from service account list.
|
||||
targetIdentities := make([]*altspb.Identity, 0, len(h.clientOpts.TargetServiceAccounts))
|
||||
for _, account := range h.clientOpts.TargetServiceAccounts {
|
||||
|
@ -229,7 +235,7 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
|
|||
return conn, authInfo, nil
|
||||
}
|
||||
|
||||
// ServerHandshake starts and completes a server ALTS handshaking for GCP. Once
|
||||
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
|
||||
// done, ServerHandshake returns a secure connection.
|
||||
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
|
||||
if !acquire() {
|
||||
|
@ -241,6 +247,16 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent
|
|||
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
|
||||
}
|
||||
|
||||
// TODO(matthewstevenson88): Change unit tests to use public APIs so
|
||||
// that h.stream can unconditionally be set based on h.clientConn.
|
||||
if h.stream == nil {
|
||||
stream, err := altsgrpc.NewHandshakerServiceClient(h.clientConn).DoHandshake(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to establish stream to ALTS handshaker service: %v", err)
|
||||
}
|
||||
h.stream = stream
|
||||
}
|
||||
|
||||
p := make([]byte, frameLimit)
|
||||
n, err := h.conn.Read(p)
|
||||
if err != nil {
|
||||
|
|
|
@ -25,6 +25,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
grpc "google.golang.org/grpc"
|
||||
core "google.golang.org/grpc/credentials/alts/internal"
|
||||
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
|
||||
|
@ -283,3 +285,65 @@ func (s) TestPeerNotResponding(t *testing.T) {
|
|||
t.Errorf("ClientHandshake() = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestNewClientHandshaker(t *testing.T) {
|
||||
conn := testutil.NewTestConn(nil, nil)
|
||||
clientConn := &grpc.ClientConn{}
|
||||
opts := &ClientHandshakerOptions{}
|
||||
hs, err := NewClientHandshaker(context.Background(), clientConn, conn, opts)
|
||||
if err != nil {
|
||||
t.Errorf("NewClientHandshaker returned unexpected error: %v", err)
|
||||
}
|
||||
expectedHs := &altsHandshaker{
|
||||
stream: nil,
|
||||
conn: conn,
|
||||
clientConn: clientConn,
|
||||
clientOpts: opts,
|
||||
serverOpts: nil,
|
||||
side: core.ClientSide,
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmp.AllowUnexported(altsHandshaker{}),
|
||||
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
|
||||
}
|
||||
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
|
||||
t.Errorf("NewClientHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
|
||||
}
|
||||
if hs.(*altsHandshaker).stream != nil {
|
||||
t.Errorf("NewClientHandshaker() returned handshaker with non-nil stream")
|
||||
}
|
||||
if hs.(*altsHandshaker).clientConn != clientConn {
|
||||
t.Errorf("NewClientHandshaker() returned handshaker with unexpected clientConn")
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestNewServerHandshaker(t *testing.T) {
|
||||
conn := testutil.NewTestConn(nil, nil)
|
||||
clientConn := &grpc.ClientConn{}
|
||||
opts := &ServerHandshakerOptions{}
|
||||
hs, err := NewServerHandshaker(context.Background(), clientConn, conn, opts)
|
||||
if err != nil {
|
||||
t.Errorf("NewServerHandshaker returned unexpected error: %v", err)
|
||||
}
|
||||
expectedHs := &altsHandshaker{
|
||||
stream: nil,
|
||||
conn: conn,
|
||||
clientConn: clientConn,
|
||||
clientOpts: nil,
|
||||
serverOpts: opts,
|
||||
side: core.ServerSide,
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmp.AllowUnexported(altsHandshaker{}),
|
||||
cmpopts.IgnoreFields(altsHandshaker{}, "conn", "clientConn"),
|
||||
}
|
||||
if got, want := hs.(*altsHandshaker), expectedHs; !cmp.Equal(got, want, cmpOpts...) {
|
||||
t.Errorf("NewServerHandshaker() returned unexpected handshaker: got: %v, want: %v", got, want)
|
||||
}
|
||||
if hs.(*altsHandshaker).stream != nil {
|
||||
t.Errorf("NewServerHandshaker() returned handshaker with non-nil stream")
|
||||
}
|
||||
if hs.(*altsHandshaker).clientConn != clientConn {
|
||||
t.Errorf("NewServerHandshaker() returned handshaker with unexpected clientConn")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue