mirror of https://github.com/grpc/grpc-go.git
[alts] Add plumbing for the bound access token field in the ALTS StartClient request. (#8284)
This commit is contained in:
parent
763d093ac8
commit
00be1e1383
|
@ -133,10 +133,11 @@ func DefaultServerOptions() *ServerOptions {
|
|||
// altsTC is the credentials required for authenticating a connection using ALTS.
|
||||
// It implements credentials.TransportCredentials interface.
|
||||
type altsTC struct {
|
||||
info *credentials.ProtocolInfo
|
||||
side core.Side
|
||||
accounts []string
|
||||
hsAddress string
|
||||
info *credentials.ProtocolInfo
|
||||
side core.Side
|
||||
accounts []string
|
||||
hsAddress string
|
||||
boundAccessToken string
|
||||
}
|
||||
|
||||
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
|
||||
|
@ -198,6 +199,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
|
|||
MaxRpcVersion: maxRPCVersion,
|
||||
MinRpcVersion: minRPCVersion,
|
||||
}
|
||||
opts.BoundAccessToken = g.boundAccessToken
|
||||
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
@ -336,6 +336,28 @@ func (s) TestFullHandshake(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestHandshakeWithAccessToken performs an ALTS handshake between a test client and
|
||||
// server, where both client and server offload to a local, fake handshaker
|
||||
// service, and expects the StartClient request to include a bound access token.
|
||||
func (s) TestHandshakeWithAccessToken(t *testing.T) {
|
||||
// Start the fake handshaker service and the server.
|
||||
var wait sync.WaitGroup
|
||||
defer wait.Wait()
|
||||
boundAccessToken := "fake-bound-access-token"
|
||||
stopHandshaker, handshakerAddress := startFakeHandshakerServiceWithExpectedBoundAccessToken(t, &wait, boundAccessToken)
|
||||
defer stopHandshaker()
|
||||
stopServer, serverAddress := startServer(t, handshakerAddress)
|
||||
defer stopServer()
|
||||
|
||||
// Ping the server, authenticating with ALTS and a bound access token.
|
||||
establishAltsConnectionWithBoundAccessToken(t, handshakerAddress, serverAddress, boundAccessToken)
|
||||
|
||||
// Close open connections to the fake handshaker service.
|
||||
if err := service.CloseForTesting(); err != nil {
|
||||
t.Errorf("service.CloseForTesting() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentHandshakes performs a several, concurrent ALTS handshakes
|
||||
// between a test client and server, where both client and server offload to a
|
||||
// local, fake handshaker service.
|
||||
|
@ -385,7 +407,15 @@ func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocol
|
|||
}
|
||||
|
||||
func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) {
|
||||
establishAltsConnectionWithBoundAccessToken(t, handshakerAddress, serverAddress, "")
|
||||
}
|
||||
|
||||
func establishAltsConnectionWithBoundAccessToken(t *testing.T, handshakerAddress, serverAddress, boundAccessToken string) {
|
||||
clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
|
||||
if boundAccessToken != "" {
|
||||
altsCreds := clientCreds.(*altsTC)
|
||||
altsCreds.boundAccessToken = boundAccessToken
|
||||
}
|
||||
conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds))
|
||||
if err != nil {
|
||||
t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err)
|
||||
|
@ -429,12 +459,20 @@ func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress stri
|
|||
}
|
||||
|
||||
func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
|
||||
return startFakeHandshakerServiceWithExpectedBoundAccessToken(t, wait, "")
|
||||
}
|
||||
|
||||
func startFakeHandshakerServiceWithExpectedBoundAccessToken(t *testing.T, wait *sync.WaitGroup, boundAccessToken string) (stop func(), address string) {
|
||||
listener, err := testutils.LocalTCPListener()
|
||||
if err != nil {
|
||||
t.Fatalf("LocalTCPListener() failed: %v", err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{})
|
||||
hs := &testutil.FakeHandshaker{}
|
||||
if boundAccessToken != "" {
|
||||
hs.ExpectedBoundAccessToken = boundAccessToken
|
||||
}
|
||||
altsgrpc.RegisterHandshakerServiceServer(s, hs)
|
||||
wait.Add(1)
|
||||
go func() {
|
||||
defer wait.Done()
|
||||
|
|
|
@ -88,6 +88,8 @@ type ClientHandshakerOptions struct {
|
|||
TargetServiceAccounts []string
|
||||
// RPCVersions specifies the gRPC versions accepted by the client.
|
||||
RPCVersions *altspb.RpcProtocolVersions
|
||||
// BoundAccessToken is a bound access token to be sent to the server for authentication.
|
||||
BoundAccessToken string
|
||||
}
|
||||
|
||||
// ServerHandshakerOptions contains the server handshaker options that can
|
||||
|
@ -195,7 +197,9 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
|
|||
},
|
||||
},
|
||||
}
|
||||
|
||||
if h.clientOpts.BoundAccessToken != "" {
|
||||
req.GetClientStart().AccessToken = h.clientOpts.BoundAccessToken
|
||||
}
|
||||
conn, result, err := h.doHandshake(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
|
@ -145,6 +145,8 @@ func MakeFrame(pl string) []byte {
|
|||
// FakeHandshaker is a fake implementation of the ALTS handshaker service.
|
||||
type FakeHandshaker struct {
|
||||
altsgrpc.HandshakerServiceServer
|
||||
// ExpectedBoundAccessToken is the expected bound access token in the ClientStart request.
|
||||
ExpectedBoundAccessToken string
|
||||
}
|
||||
|
||||
// DoHandshake performs a fake ALTS handshake.
|
||||
|
@ -221,6 +223,9 @@ func (h *FakeHandshaker) processStartClient(req *altspb.StartClientHandshakeReq)
|
|||
if len(req.RecordProtocols) != 1 || req.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
|
||||
return nil, fmt.Errorf("unexpected record protocols: %v", req.RecordProtocols)
|
||||
}
|
||||
if h.ExpectedBoundAccessToken != req.GetAccessToken() {
|
||||
return nil, fmt.Errorf("unexpected access token: %v", req.GetAccessToken())
|
||||
}
|
||||
return &altspb.HandshakerResp{
|
||||
OutFrames: []byte("ClientInit"),
|
||||
BytesConsumed: 0,
|
||||
|
|
Loading…
Reference in New Issue