From 00be1e13830fd14c5c52cf021eb3a5370e7f782d Mon Sep 17 00:00:00 2001 From: Matthew Stevenson <52979934+matthewstevenson88@users.noreply.github.com> Date: Mon, 5 May 2025 08:07:34 -0700 Subject: [PATCH] [alts] Add plumbing for the bound access token field in the ALTS StartClient request. (#8284) --- credentials/alts/alts.go | 10 +++-- credentials/alts/alts_test.go | 40 ++++++++++++++++++- .../alts/internal/handshaker/handshaker.go | 6 ++- .../alts/internal/testutil/testutil.go | 5 +++ 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/credentials/alts/alts.go b/credentials/alts/alts.go index afcdb8a0d..35539eb1a 100644 --- a/credentials/alts/alts.go +++ b/credentials/alts/alts.go @@ -133,10 +133,11 @@ func DefaultServerOptions() *ServerOptions { // altsTC is the credentials required for authenticating a connection using ALTS. // It implements credentials.TransportCredentials interface. type altsTC struct { - info *credentials.ProtocolInfo - side core.Side - accounts []string - hsAddress string + info *credentials.ProtocolInfo + side core.Side + accounts []string + hsAddress string + boundAccessToken string } // NewClientCreds constructs a client-side ALTS TransportCredentials object. @@ -198,6 +199,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C MaxRpcVersion: maxRPCVersion, MinRpcVersion: minRPCVersion, } + opts.BoundAccessToken = g.boundAccessToken chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts) if err != nil { return nil, nil, err diff --git a/credentials/alts/alts_test.go b/credentials/alts/alts_test.go index 48f871a00..4710a32b1 100644 --- a/credentials/alts/alts_test.go +++ b/credentials/alts/alts_test.go @@ -336,6 +336,28 @@ func (s) TestFullHandshake(t *testing.T) { } } +// TestHandshakeWithAccessToken performs an ALTS handshake between a test client and +// server, where both client and server offload to a local, fake handshaker +// service, and expects the StartClient request to include a bound access token. +func (s) TestHandshakeWithAccessToken(t *testing.T) { + // Start the fake handshaker service and the server. + var wait sync.WaitGroup + defer wait.Wait() + boundAccessToken := "fake-bound-access-token" + stopHandshaker, handshakerAddress := startFakeHandshakerServiceWithExpectedBoundAccessToken(t, &wait, boundAccessToken) + defer stopHandshaker() + stopServer, serverAddress := startServer(t, handshakerAddress) + defer stopServer() + + // Ping the server, authenticating with ALTS and a bound access token. + establishAltsConnectionWithBoundAccessToken(t, handshakerAddress, serverAddress, boundAccessToken) + + // Close open connections to the fake handshaker service. + if err := service.CloseForTesting(); err != nil { + t.Errorf("service.CloseForTesting() failed: %v", err) + } +} + // TestConcurrentHandshakes performs a several, concurrent ALTS handshakes // between a test client and server, where both client and server offload to a // local, fake handshaker service. @@ -385,7 +407,15 @@ func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocol } func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) { + establishAltsConnectionWithBoundAccessToken(t, handshakerAddress, serverAddress, "") +} + +func establishAltsConnectionWithBoundAccessToken(t *testing.T, handshakerAddress, serverAddress, boundAccessToken string) { clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress}) + if boundAccessToken != "" { + altsCreds := clientCreds.(*altsTC) + altsCreds.boundAccessToken = boundAccessToken + } conn, err := grpc.NewClient(serverAddress, grpc.WithTransportCredentials(clientCreds)) if err != nil { t.Fatalf("grpc.NewClient(%v) failed: %v", serverAddress, err) @@ -429,12 +459,20 @@ func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress stri } func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) { + return startFakeHandshakerServiceWithExpectedBoundAccessToken(t, wait, "") +} + +func startFakeHandshakerServiceWithExpectedBoundAccessToken(t *testing.T, wait *sync.WaitGroup, boundAccessToken string) (stop func(), address string) { listener, err := testutils.LocalTCPListener() if err != nil { t.Fatalf("LocalTCPListener() failed: %v", err) } s := grpc.NewServer() - altsgrpc.RegisterHandshakerServiceServer(s, &testutil.FakeHandshaker{}) + hs := &testutil.FakeHandshaker{} + if boundAccessToken != "" { + hs.ExpectedBoundAccessToken = boundAccessToken + } + altsgrpc.RegisterHandshakerServiceServer(s, hs) wait.Add(1) go func() { defer wait.Done() diff --git a/credentials/alts/internal/handshaker/handshaker.go b/credentials/alts/internal/handshaker/handshaker.go index becd2f3bd..0360842eb 100644 --- a/credentials/alts/internal/handshaker/handshaker.go +++ b/credentials/alts/internal/handshaker/handshaker.go @@ -88,6 +88,8 @@ type ClientHandshakerOptions struct { TargetServiceAccounts []string // RPCVersions specifies the gRPC versions accepted by the client. RPCVersions *altspb.RpcProtocolVersions + // BoundAccessToken is a bound access token to be sent to the server for authentication. + BoundAccessToken string } // ServerHandshakerOptions contains the server handshaker options that can @@ -195,7 +197,9 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent }, }, } - + if h.clientOpts.BoundAccessToken != "" { + req.GetClientStart().AccessToken = h.clientOpts.BoundAccessToken + } conn, result, err := h.doHandshake(req) if err != nil { return nil, nil, err diff --git a/credentials/alts/internal/testutil/testutil.go b/credentials/alts/internal/testutil/testutil.go index 8ab94133f..1dfccdc06 100644 --- a/credentials/alts/internal/testutil/testutil.go +++ b/credentials/alts/internal/testutil/testutil.go @@ -145,6 +145,8 @@ func MakeFrame(pl string) []byte { // FakeHandshaker is a fake implementation of the ALTS handshaker service. type FakeHandshaker struct { altsgrpc.HandshakerServiceServer + // ExpectedBoundAccessToken is the expected bound access token in the ClientStart request. + ExpectedBoundAccessToken string } // DoHandshake performs a fake ALTS handshake. @@ -221,6 +223,9 @@ func (h *FakeHandshaker) processStartClient(req *altspb.StartClientHandshakeReq) if len(req.RecordProtocols) != 1 || req.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" { return nil, fmt.Errorf("unexpected record protocols: %v", req.RecordProtocols) } + if h.ExpectedBoundAccessToken != req.GetAccessToken() { + return nil, fmt.Errorf("unexpected access token: %v", req.GetAccessToken()) + } return &altspb.HandshakerResp{ OutFrames: []byte("ClientInit"), BytesConsumed: 0,