diff --git a/credentials/alts/alts_test.go b/credentials/alts/alts_test.go index aef9642f8..9a95d4628 100644 --- a/credentials/alts/alts_test.go +++ b/credentials/alts/alts_test.go @@ -24,7 +24,6 @@ package alts import ( "context" "reflect" - "runtime" "sync" "testing" "time" @@ -309,21 +308,12 @@ func (s) TestCheckRPCVersions(t *testing.T) { // server, where both client and server offload to a local, fake handshaker // service. func (s) TestFullHandshake(t *testing.T) { - // If GOMAXPROCS is set to less than 2, do not run this test. This test - // requires at least 2 goroutines to succeed (one goroutine where a - // server listens, another goroutine where a client runs). - if runtime.GOMAXPROCS(0) < 2 { - return - } - // The vmOnGCP global variable MUST be reset to true after the client // or server credentials have been created, but before the ALTS // handshake begins. If vmOnGCP is not reset and this test is run // anywhere except for a GCP VM, then the ALTS handshake will // immediately fail. - once.Do(func() { - vmOnGCP = true - }) + once.Do(func() {}) vmOnGCP = true // Start the fake handshaker service and the server. diff --git a/credentials/alts/internal/testutil/testutil.go b/credentials/alts/internal/testutil/testutil.go index 24a61202a..cdc88c8f9 100644 --- a/credentials/alts/internal/testutil/testutil.go +++ b/credentials/alts/internal/testutil/testutil.go @@ -136,6 +136,7 @@ type FakeHandshaker struct { // DoHandshake performs a fake ALTS handshake. func (h *FakeHandshaker) DoHandshake(stream altsgrpc.HandshakerService_DoHandshakeServer) error { var isAssistingClient bool + var handshakeFramesReceivedSoFar []byte for { req, err := stream.Recv() if err != nil { @@ -153,15 +154,38 @@ func (h *FakeHandshaker) DoHandshake(stream altsgrpc.HandshakerService_DoHandsha return fmt.Errorf("processStartClient failure: %v", err) } case *altspb.HandshakerReq_ServerStart: + // If we have received the full ClientInit, send the ServerInit and + // ServerFinished. Otherwise, wait for more bytes to arrive from the client. isAssistingClient = false - resp, err = h.processServerStart(req.ServerStart) + handshakeFramesReceivedSoFar = append(handshakeFramesReceivedSoFar, req.ServerStart.InBytes...) + sendHandshakeFrame := bytes.Equal(handshakeFramesReceivedSoFar, []byte("ClientInit")) + resp, err = h.processServerStart(req.ServerStart, sendHandshakeFrame) if err != nil { - return fmt.Errorf("processServerClient failure: %v", err) + return fmt.Errorf("processServerStart failure: %v", err) } case *altspb.HandshakerReq_Next: - resp, err = h.processNext(req.Next, isAssistingClient) + // If we have received all handshake frames, send the handshake result. + // Otherwise, wait for more bytes to arrive from the peer. + oldHandshakesBytes := len(handshakeFramesReceivedSoFar) + handshakeFramesReceivedSoFar = append(handshakeFramesReceivedSoFar, req.Next.InBytes...) + isHandshakeComplete := false + if isAssistingClient { + isHandshakeComplete = bytes.HasPrefix(handshakeFramesReceivedSoFar, []byte("ServerInitServerFinished")) + } else { + isHandshakeComplete = bytes.HasPrefix(handshakeFramesReceivedSoFar, []byte("ClientInitClientFinished")) + } + if !isHandshakeComplete { + resp = &altspb.HandshakerResp{ + BytesConsumed: uint32(len(handshakeFramesReceivedSoFar) - oldHandshakesBytes), + Status: &altspb.HandshakerStatus{ + Code: uint32(codes.OK), + }, + } + break + } + resp, err = h.getHandshakeResult(isAssistingClient) if err != nil { - return fmt.Errorf("processNext failure: %v", err) + return fmt.Errorf("getHandshakeResult failure: %v", err) } default: return fmt.Errorf("handshake request has unexpected type: %v", req) @@ -192,7 +216,7 @@ func (h *FakeHandshaker) processStartClient(req *altspb.StartClientHandshakeReq) }, nil } -func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq) (*altspb.HandshakerResp, error) { +func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq, sendHandshakeFrame bool) (*altspb.HandshakerResp, error) { if len(req.ApplicationProtocols) != 1 || req.ApplicationProtocols[0] != "grpc" { return nil, fmt.Errorf("unexpected application protocols: %v", req.ApplicationProtocols) } @@ -203,8 +227,14 @@ func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq) if len(parameters.RecordProtocols) != 1 || parameters.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" { return nil, fmt.Errorf("unexpected record protocols: %v", parameters.RecordProtocols) } - if string(req.InBytes) != "ClientInit" { - return nil, fmt.Errorf("unexpected in bytes: %v", req.InBytes) + if sendHandshakeFrame { + return &altspb.HandshakerResp{ + OutFrames: []byte("ServerInitServerFinished"), + BytesConsumed: uint32(len(req.InBytes)), + Status: &altspb.HandshakerStatus{ + Code: uint32(codes.OK), + }, + }, nil } return &altspb.HandshakerResp{ OutFrames: []byte("ServerInitServerFinished"), @@ -215,11 +245,8 @@ func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq) }, nil } -func (h *FakeHandshaker) processNext(req *altspb.NextHandshakeMessageReq, isAssistingClient bool) (*altspb.HandshakerResp, error) { +func (h *FakeHandshaker) getHandshakeResult(isAssistingClient bool) (*altspb.HandshakerResp, error) { if isAssistingClient { - if !bytes.Equal(req.InBytes, []byte("ServerInitServerFinished")) { - return nil, fmt.Errorf("unexpected in bytes: got: %v, want: %v", req.InBytes, []byte("ServerInitServerFinished")) - } return &altspb.HandshakerResp{ OutFrames: []byte("ClientFinished"), BytesConsumed: 24, @@ -248,9 +275,6 @@ func (h *FakeHandshaker) processNext(req *altspb.NextHandshakeMessageReq, isAssi }, }, nil } - if !bytes.Equal(req.InBytes, []byte("ClientFinished")) { - return nil, fmt.Errorf("unexpected in bytes: got: %v, want: %v", req.InBytes, []byte("ClientFinished")) - } return &altspb.HandshakerResp{ BytesConsumed: 14, Result: &altspb.HandshakerResult{