mirror of https://github.com/grpc/grpc-go.git
[alts] Add plumbing for the bound access token field in the ALTS StartClient request. (#8284)
This commit is contained in:
parent
763d093ac8
commit
00be1e1383
|
@ -133,10 +133,11 @@ func DefaultServerOptions() *ServerOptions {
|
||||||
// altsTC is the credentials required for authenticating a connection using ALTS.
|
// altsTC is the credentials required for authenticating a connection using ALTS.
|
||||||
// It implements credentials.TransportCredentials interface.
|
// It implements credentials.TransportCredentials interface.
|
||||||
type altsTC struct {
|
type altsTC struct {
|
||||||
info *credentials.ProtocolInfo
|
info *credentials.ProtocolInfo
|
||||||
side core.Side
|
side core.Side
|
||||||
accounts []string
|
accounts []string
|
||||||
hsAddress string
|
hsAddress string
|
||||||
|
boundAccessToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
|
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
|
||||||
|
@ -198,6 +199,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
|
||||||
MaxRpcVersion: maxRPCVersion,
|
MaxRpcVersion: maxRPCVersion,
|
||||||
MinRpcVersion: minRPCVersion,
|
MinRpcVersion: minRPCVersion,
|
||||||
}
|
}
|
||||||
|
opts.BoundAccessToken = g.boundAccessToken
|
||||||
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
|
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|
|
@ -336,6 +336,28 @@ func (s) TestFullHandshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestHandshakeWithAccessToken performs an ALTS handshake between a test client and
|
||||||
|
// server, where both client and server offload to a local, fake handshaker
|
||||||
|
// service, and expects the StartClient request to include a bound access token.
|
||||||
|
func (s) TestHandshakeWithAccessToken(t *testing.T) {
|
||||||
|
// Start the fake handshaker service and the server.
|
||||||
|
var wait sync.WaitGroup
|
||||||
|
defer wait.Wait()
|
||||||
|
boundAccessToken := "fake-bound-access-token"
|
||||||
|
stopHandshaker, handshakerAddress := startFakeHandshakerServiceWithExpectedBoundAccessToken(t, &wait, boundAccessToken)
|
||||||
|
defer stopHandshaker()
|
||||||
|
stopServer, serverAddress := startServer(t, handshakerAddress)
|
||||||
|
defer stopServer()
|
||||||
|
|
||||||
|
// Ping the server, authenticating with ALTS and a bound access token.
|
||||||
|
establishAltsConnectionWithBoundAccessToken(t, handshakerAddress, serverAddress, boundAccessToken)
|
||||||
|
|
||||||
|
// Close open connections to the fake handshaker service.
|
||||||
|
if err := service.CloseForTesting(); err != nil {
|
||||||
|
t.Errorf("service.CloseForTesting() failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestConcurrentHandshakes performs a several, concurrent ALTS handshakes
|
// TestConcurrentHandshakes performs a several, concurrent ALTS handshakes
|
||||||
// between a test client and server, where both client and server offload to a
|
// between a test client and server, where both client and server offload to a
|
||||||
// local, fake handshaker service.
|
// local, fake handshaker service.
|
||||||
|
@ -385,7 +407,15 @@ func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) {
|
func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) {
|
||||||
|
establishAltsConnectionWithBoundAccessToken(t, handshakerAddress, serverAddress, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func establishAltsConnectionWithBoundAccessToken(t *testing.T, handshakerAddress, serverAddress, boundAccessToken string) {
|
||||||
clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
|
clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
|
||||||
|
if boundAccessToken != "" {
|
||||||
|
altsCreds := clientCreds.(*altsTC)
|
||||||
|
altsCreds.boundAccessToken = boundAccessToken
|
||||||
|
}
|
||||||
conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds))
|
conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err)
|
t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err)
|
||||||
|
@ -429,12 +459,20 @@ func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
|
func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
|
||||||
|
return startFakeHandshakerServiceWithExpectedBoundAccessToken(t, wait, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func startFakeHandshakerServiceWithExpectedBoundAccessToken(t *testing.T, wait *sync.WaitGroup, boundAccessToken string) (stop func(), address string) {
|
||||||
listener, err := testutils.LocalTCPListener()
|
listener, err := testutils.LocalTCPListener()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("LocalTCPListener() failed: %v", err)
|
t.Fatalf("LocalTCPListener() failed: %v", err)
|
||||||
}
|
}
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{})
|
hs := &testutil.FakeHandshaker{}
|
||||||
|
if boundAccessToken != "" {
|
||||||
|
hs.ExpectedBoundAccessToken = boundAccessToken
|
||||||
|
}
|
||||||
|
altsgrpc.RegisterHandshakerServiceServer(s, hs)
|
||||||
wait.Add(1)
|
wait.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wait.Done()
|
defer wait.Done()
|
||||||
|
|
|
@ -88,6 +88,8 @@ type ClientHandshakerOptions struct {
|
||||||
TargetServiceAccounts []string
|
TargetServiceAccounts []string
|
||||||
// RPCVersions specifies the gRPC versions accepted by the client.
|
// RPCVersions specifies the gRPC versions accepted by the client.
|
||||||
RPCVersions *altspb.RpcProtocolVersions
|
RPCVersions *altspb.RpcProtocolVersions
|
||||||
|
// BoundAccessToken is a bound access token to be sent to the server for authentication.
|
||||||
|
BoundAccessToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerHandshakerOptions contains the server handshaker options that can
|
// ServerHandshakerOptions contains the server handshaker options that can
|
||||||
|
@ -195,7 +197,9 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
if h.clientOpts.BoundAccessToken != "" {
|
||||||
|
req.GetClientStart().AccessToken = h.clientOpts.BoundAccessToken
|
||||||
|
}
|
||||||
conn, result, err := h.doHandshake(req)
|
conn, result, err := h.doHandshake(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|
|
@ -145,6 +145,8 @@ func MakeFrame(pl string) []byte {
|
||||||
// FakeHandshaker is a fake implementation of the ALTS handshaker service.
|
// FakeHandshaker is a fake implementation of the ALTS handshaker service.
|
||||||
type FakeHandshaker struct {
|
type FakeHandshaker struct {
|
||||||
altsgrpc.HandshakerServiceServer
|
altsgrpc.HandshakerServiceServer
|
||||||
|
// ExpectedBoundAccessToken is the expected bound access token in the ClientStart request.
|
||||||
|
ExpectedBoundAccessToken string
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoHandshake performs a fake ALTS handshake.
|
// DoHandshake performs a fake ALTS handshake.
|
||||||
|
@ -221,6 +223,9 @@ func (h *FakeHandshaker) processStartClient(req *altspb.StartClientHandshakeReq)
|
||||||
if len(req.RecordProtocols) != 1 || req.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
|
if len(req.RecordProtocols) != 1 || req.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
|
||||||
return nil, fmt.Errorf("unexpected record protocols: %v", req.RecordProtocols)
|
return nil, fmt.Errorf("unexpected record protocols: %v", req.RecordProtocols)
|
||||||
}
|
}
|
||||||
|
if h.ExpectedBoundAccessToken != req.GetAccessToken() {
|
||||||
|
return nil, fmt.Errorf("unexpected access token: %v", req.GetAccessToken())
|
||||||
|
}
|
||||||
return &altspb.HandshakerResp{
|
return &altspb.HandshakerResp{
|
||||||
OutFrames: []byte("ClientInit"),
|
OutFrames: []byte("ClientInit"),
|
||||||
BytesConsumed: 0,
|
BytesConsumed: 0,
|
||||||
|
|
Loading…
Reference in New Issue