alts: Fix flaky ALTS TestFullHandshake test. (#6300)

* Fix flaky ALTS FullHandshake test.

* Fix one other flake possibility.

* fix typo in comment

* Wait for full handshake frames to arrive from peer.

* Remove runtime.GOMAXPROCS from the test.

* Only set vmOnGCP once.
This commit is contained in:
Matthew Stevenson 2023-05-25 15:05:50 -07:00 committed by GitHub
parent 4d3f221d1d
commit e325737cac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 25 deletions

View File

@ -24,7 +24,6 @@ package alts
import (
"context"
"reflect"
"runtime"
"sync"
"testing"
"time"
@ -309,21 +308,12 @@ func (s) TestCheckRPCVersions(t *testing.T) {
// server, where both client and server offload to a local, fake handshaker
// service.
func (s) TestFullHandshake(t *testing.T) {
// If GOMAXPROCS is set to less than 2, do not run this test. This test
// requires at least 2 goroutines to succeed (one goroutine where a
// server listens, another goroutine where a client runs).
if runtime.GOMAXPROCS(0) < 2 {
return
}
// The vmOnGCP global variable MUST be reset to true after the client
// or server credentials have been created, but before the ALTS
// handshake begins. If vmOnGCP is not reset and this test is run
// anywhere except for a GCP VM, then the ALTS handshake will
// immediately fail.
once.Do(func() {
vmOnGCP = true
})
once.Do(func() {})
vmOnGCP = true
// Start the fake handshaker service and the server.

View File

@ -136,6 +136,7 @@ type FakeHandshaker struct {
// DoHandshake performs a fake ALTS handshake.
func (h *FakeHandshaker) DoHandshake(stream altsgrpc.HandshakerService_DoHandshakeServer) error {
var isAssistingClient bool
var handshakeFramesReceivedSoFar []byte
for {
req, err := stream.Recv()
if err != nil {
@ -153,15 +154,38 @@ func (h *FakeHandshaker) DoHandshake(stream altsgrpc.HandshakerService_DoHandsha
return fmt.Errorf("processStartClient failure: %v", err)
}
case *altspb.HandshakerReq_ServerStart:
// If we have received the full ClientInit, send the ServerInit and
// ServerFinished. Otherwise, wait for more bytes to arrive from the client.
isAssistingClient = false
resp, err = h.processServerStart(req.ServerStart)
handshakeFramesReceivedSoFar = append(handshakeFramesReceivedSoFar, req.ServerStart.InBytes...)
sendHandshakeFrame := bytes.Equal(handshakeFramesReceivedSoFar, []byte("ClientInit"))
resp, err = h.processServerStart(req.ServerStart, sendHandshakeFrame)
if err != nil {
return fmt.Errorf("processServerClient failure: %v", err)
return fmt.Errorf("processServerStart failure: %v", err)
}
case *altspb.HandshakerReq_Next:
resp, err = h.processNext(req.Next, isAssistingClient)
// If we have received all handshake frames, send the handshake result.
// Otherwise, wait for more bytes to arrive from the peer.
oldHandshakesBytes := len(handshakeFramesReceivedSoFar)
handshakeFramesReceivedSoFar = append(handshakeFramesReceivedSoFar, req.Next.InBytes...)
isHandshakeComplete := false
if isAssistingClient {
isHandshakeComplete = bytes.HasPrefix(handshakeFramesReceivedSoFar, []byte("ServerInitServerFinished"))
} else {
isHandshakeComplete = bytes.HasPrefix(handshakeFramesReceivedSoFar, []byte("ClientInitClientFinished"))
}
if !isHandshakeComplete {
resp = &altspb.HandshakerResp{
BytesConsumed: uint32(len(handshakeFramesReceivedSoFar) - oldHandshakesBytes),
Status: &altspb.HandshakerStatus{
Code: uint32(codes.OK),
},
}
break
}
resp, err = h.getHandshakeResult(isAssistingClient)
if err != nil {
return fmt.Errorf("processNext failure: %v", err)
return fmt.Errorf("getHandshakeResult failure: %v", err)
}
default:
return fmt.Errorf("handshake request has unexpected type: %v", req)
@ -192,7 +216,7 @@ func (h *FakeHandshaker) processStartClient(req *altspb.StartClientHandshakeReq)
}, nil
}
func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq) (*altspb.HandshakerResp, error) {
func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq, sendHandshakeFrame bool) (*altspb.HandshakerResp, error) {
if len(req.ApplicationProtocols) != 1 || req.ApplicationProtocols[0] != "grpc" {
return nil, fmt.Errorf("unexpected application protocols: %v", req.ApplicationProtocols)
}
@ -203,8 +227,14 @@ func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq)
if len(parameters.RecordProtocols) != 1 || parameters.RecordProtocols[0] != "ALTSRP_GCM_AES128_REKEY" {
return nil, fmt.Errorf("unexpected record protocols: %v", parameters.RecordProtocols)
}
if string(req.InBytes) != "ClientInit" {
return nil, fmt.Errorf("unexpected in bytes: %v", req.InBytes)
if sendHandshakeFrame {
return &altspb.HandshakerResp{
OutFrames: []byte("ServerInitServerFinished"),
BytesConsumed: uint32(len(req.InBytes)),
Status: &altspb.HandshakerStatus{
Code: uint32(codes.OK),
},
}, nil
}
return &altspb.HandshakerResp{
OutFrames: []byte("ServerInitServerFinished"),
@ -215,11 +245,8 @@ func (h *FakeHandshaker) processServerStart(req *altspb.StartServerHandshakeReq)
}, nil
}
func (h *FakeHandshaker) processNext(req *altspb.NextHandshakeMessageReq, isAssistingClient bool) (*altspb.HandshakerResp, error) {
func (h *FakeHandshaker) getHandshakeResult(isAssistingClient bool) (*altspb.HandshakerResp, error) {
if isAssistingClient {
if !bytes.Equal(req.InBytes, []byte("ServerInitServerFinished")) {
return nil, fmt.Errorf("unexpected in bytes: got: %v, want: %v", req.InBytes, []byte("ServerInitServerFinished"))
}
return &altspb.HandshakerResp{
OutFrames: []byte("ClientFinished"),
BytesConsumed: 24,
@ -248,9 +275,6 @@ func (h *FakeHandshaker) processNext(req *altspb.NextHandshakeMessageReq, isAssi
},
}, nil
}
if !bytes.Equal(req.InBytes, []byte("ClientFinished")) {
return nil, fmt.Errorf("unexpected in bytes: got: %v, want: %v", req.InBytes, []byte("ClientFinished"))
}
return &altspb.HandshakerResp{
BytesConsumed: 14,
Result: &altspb.HandshakerResult{