[alts] Add plumbing for the bound access token field in the ALTS StartClient request. (#8284)

This commit is contained in:
Matthew Stevenson 2025-05-05 08:07:34 -07:00 committed by GitHub
parent 763d093ac8
commit 00be1e1383
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 55 additions and 6 deletions

View File

@ -133,10 +133,11 @@ func DefaultServerOptions() *ServerOptions {
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
type altsTC struct {
info *credentials.ProtocolInfo
side core.Side
accounts []string
hsAddress string
info *credentials.ProtocolInfo
side core.Side
accounts []string
hsAddress string
boundAccessToken string
}
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
@ -198,6 +199,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
opts.BoundAccessToken = g.boundAccessToken
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
if err != nil {
return nil, nil, err

View File

@ -336,6 +336,28 @@ func (s) TestFullHandshake(t *testing.T) {
}
}
// TestHandshakeWithAccessToken performs an ALTS handshake between a test client and
// server, where both client and server offload to a local, fake handshaker
// service, and expects the StartClient request to include a bound access token.
func (s) TestHandshakeWithAccessToken(t *testing.T) {
// Start the fake handshaker service and the server.
var wait sync.WaitGroup
defer wait.Wait()
boundAccessToken := "fake-bound-access-token"
stopHandshaker, handshakerAddress := startFakeHandshakerServiceWithExpectedBoundAccessToken(t, &wait, boundAccessToken)
defer stopHandshaker()
stopServer, serverAddress := startServer(t, handshakerAddress)
defer stopServer()
// Ping the server, authenticating with ALTS and a bound access token.
establishAltsConnectionWithBoundAccessToken(t, handshakerAddress, serverAddress, boundAccessToken)
// Close open connections to the fake handshaker service.
if err := service.CloseForTesting(); err != nil {
t.Errorf("service.CloseForTesting() failed: %v", err)
}
}
// TestConcurrentHandshakes performs a several, concurrent ALTS handshakes
// between a test client and server, where both client and server offload to a
// local, fake handshaker service.
@ -385,7 +407,15 @@ func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocol
}
func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) {
establishAltsConnectionWithBoundAccessToken(t, handshakerAddress, serverAddress, "")
}
func establishAltsConnectionWithBoundAccessToken(t *testing.T, handshakerAddress, serverAddress, boundAccessToken string) {
clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
if boundAccessToken != "" {
altsCreds := clientCreds.(*altsTC)
altsCreds.boundAccessToken = boundAccessToken
}
conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds))
if err != nil {
t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err)
@ -429,12 +459,20 @@ func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress stri
}
func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
return startFakeHandshakerServiceWithExpectedBoundAccessToken(t, wait, "")
}
func startFakeHandshakerServiceWithExpectedBoundAccessToken(t *testing.T, wait *sync.WaitGroup, boundAccessToken string) (stop func(), address string) {
listener, err := testutils.LocalTCPListener()
if err != nil {
t.Fatalf("LocalTCPListener() failed: %v", err)
}
s := grpc.NewServer()
altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{})
hs := &testutil.FakeHandshaker{}
if boundAccessToken != "" {
hs.ExpectedBoundAccessToken = boundAccessToken
}
altsgrpc.RegisterHandshakerServiceServer(s, hs)
wait.Add(1)
go func() {
defer wait.Done()

View File

@ -88,6 +88,8 @@ type ClientHandshakerOptions struct {
TargetServiceAccounts []string
// RPCVersions specifies the gRPC versions accepted by the client.
RPCVersions *altspb.RpcProtocolVersions
// BoundAccessToken is a bound access token to be sent to the server for authentication.
BoundAccessToken string
}
// ServerHandshakerOptions contains the server handshaker options that can
@ -195,7 +197,9 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
},
},
}
if h.clientOpts.BoundAccessToken != "" {
req.GetClientStart().AccessToken = h.clientOpts.BoundAccessToken
}
conn, result, err := h.doHandshake(req)
if err != nil {
return nil, nil, err

View File

@ -145,6 +145,8 @@ func MakeFrame(pl string) []byte {
// FakeHandshaker is a fake implementation of the ALTS handshaker service.
type FakeHandshaker struct {
altsgrpc.HandshakerServiceServer
// ExpectedBoundAccessToken is the expected bound access token in the ClientStart request.
ExpectedBoundAccessToken string
}
// DoHandshake performs a fake ALTS handshake.
@ -221,6 +223,9 @@ func (h *FakeHandshaker) processStartClient(req *altspb.StartClientHandshakeReq)
if len(req.RecordProtocols) != 1 || req.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
return nil, fmt.Errorf("unexpected record protocols: %v", req.RecordProtocols)
}
if h.ExpectedBoundAccessToken != req.GetAccessToken() {
return nil, fmt.Errorf("unexpected access token: %v", req.GetAccessToken())
}
return &altspb.HandshakerResp{
OutFrames: []byte("ClientInit"),
BytesConsumed: 0,