mirror of https://github.com/grpc/grpc-go.git
286 lines
7.6 KiB
Go
286 lines
7.6 KiB
Go
/*
 *
 * Copyright 2018 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
|
|
|
|
package handshaker
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
grpc "google.golang.org/grpc"
|
|
core "google.golang.org/grpc/credentials/alts/internal"
|
|
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
|
|
"google.golang.org/grpc/credentials/alts/internal/testutil"
|
|
"google.golang.org/grpc/internal/grpctest"
|
|
)
|
|
|
|
type s struct {
|
|
grpctest.Tester
|
|
}
|
|
|
|
func Test(t *testing.T) {
|
|
grpctest.RunSubTests(t, s{})
|
|
}
|
|
|
|
var (
|
|
testRecordProtocol = rekeyRecordProtocolName
|
|
testKey = []byte{
|
|
// 44 arbitrary bytes.
|
|
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49,
|
|
0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49, 0x1f, 0x8b,
|
|
0xd2, 0x4c, 0xce, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2,
|
|
}
|
|
testServiceAccount = "test_service_account"
|
|
testTargetServiceAccounts = []string{testServiceAccount}
|
|
testClientIdentity = &altspb.Identity{
|
|
IdentityOneof: &altspb.Identity_Hostname{
|
|
Hostname: "i_am_a_client",
|
|
},
|
|
}
|
|
)
|
|
|
|
// defaultTestTimeout bounds the context handed to each handshake call in
// these tests.
const defaultTestTimeout time.Duration = 10 * time.Second
|
|
|
|
// testRPCStream mimics a altspb.HandshakerService_DoHandshakeClient object.
|
|
type testRPCStream struct {
|
|
grpc.ClientStream
|
|
t *testing.T
|
|
isClient bool
|
|
// The resp expected to be returned by Recv(). Make sure this is set to
|
|
// the content the test requires before Recv() is invoked.
|
|
recvBuf *altspb.HandshakerResp
|
|
// false if it is the first access to Handshaker service on Envelope.
|
|
first bool
|
|
// useful for testing concurrent calls.
|
|
delay time.Duration
|
|
}
|
|
|
|
func (t *testRPCStream) Recv() (*altspb.HandshakerResp, error) {
|
|
resp := t.recvBuf
|
|
t.recvBuf = nil
|
|
return resp, nil
|
|
}
|
|
|
|
func (t *testRPCStream) Send(req *altspb.HandshakerReq) error {
|
|
var resp *altspb.HandshakerResp
|
|
if !t.first {
|
|
// Generate the bytes to be returned by Recv() for the initial
|
|
// handshaking.
|
|
t.first = true
|
|
if t.isClient {
|
|
resp = &altspb.HandshakerResp{
|
|
OutFrames: testutil.MakeFrame("ClientInit"),
|
|
// Simulate consuming ServerInit.
|
|
BytesConsumed: 14,
|
|
}
|
|
} else {
|
|
resp = &altspb.HandshakerResp{
|
|
OutFrames: testutil.MakeFrame("ServerInit"),
|
|
// Simulate consuming ClientInit.
|
|
BytesConsumed: 14,
|
|
}
|
|
}
|
|
} else {
|
|
// Add delay to test concurrent calls.
|
|
cleanup := stat.Update()
|
|
defer cleanup()
|
|
time.Sleep(t.delay)
|
|
|
|
// Generate the response to be returned by Recv() for the
|
|
// follow-up handshaking.
|
|
result := &altspb.HandshakerResult{
|
|
RecordProtocol: testRecordProtocol,
|
|
KeyData: testKey,
|
|
}
|
|
resp = &altspb.HandshakerResp{
|
|
Result: result,
|
|
// Simulate consuming ClientFinished or ServerFinished.
|
|
BytesConsumed: 18,
|
|
}
|
|
}
|
|
t.recvBuf = resp
|
|
return nil
|
|
}
|
|
|
|
func (t *testRPCStream) CloseSend() error {
|
|
return nil
|
|
}
|
|
|
|
// stat tracks the follow-up Send() calls made through testRPCStream. The
// handshake tests reset it per test case and read MaxConcurrentCalls from it.
var stat testutil.Stats
|
|
|
|
func (s) TestClientHandshake(t *testing.T) {
|
|
for _, testCase := range []struct {
|
|
delay time.Duration
|
|
numberOfHandshakes int
|
|
}{
|
|
{0 * time.Millisecond, 1},
|
|
{100 * time.Millisecond, 10 * maxPendingHandshakes},
|
|
} {
|
|
errc := make(chan error)
|
|
stat.Reset()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
defer cancel()
|
|
|
|
for i := 0; i < testCase.numberOfHandshakes; i++ {
|
|
stream := &testRPCStream{
|
|
t: t,
|
|
isClient: true,
|
|
}
|
|
// Preload the inbound frames.
|
|
f1 := testutil.MakeFrame("ServerInit")
|
|
f2 := testutil.MakeFrame("ServerFinished")
|
|
in := bytes.NewBuffer(f1)
|
|
in.Write(f2)
|
|
out := new(bytes.Buffer)
|
|
tc := testutil.NewTestConn(in, out)
|
|
chs := &altsHandshaker{
|
|
stream: stream,
|
|
conn: tc,
|
|
clientOpts: &ClientHandshakerOptions{
|
|
TargetServiceAccounts: testTargetServiceAccounts,
|
|
ClientIdentity: testClientIdentity,
|
|
},
|
|
side: core.ClientSide,
|
|
}
|
|
go func() {
|
|
_, context, err := chs.ClientHandshake(ctx)
|
|
if err == nil && context == nil {
|
|
errc <- errors.New("expected non-nil ALTS context")
|
|
return
|
|
}
|
|
errc <- err
|
|
chs.Close()
|
|
}()
|
|
}
|
|
|
|
// Ensure all errors are expected.
|
|
for i := 0; i < testCase.numberOfHandshakes; i++ {
|
|
if err := <-errc; err != nil && err != errDropped {
|
|
t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
|
|
}
|
|
}
|
|
|
|
// Ensure that there are no concurrent calls more than the limit.
|
|
if stat.MaxConcurrentCalls > maxPendingHandshakes {
|
|
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s) TestServerHandshake(t *testing.T) {
|
|
for _, testCase := range []struct {
|
|
delay time.Duration
|
|
numberOfHandshakes int
|
|
}{
|
|
{0 * time.Millisecond, 1},
|
|
{100 * time.Millisecond, 10 * maxPendingHandshakes},
|
|
} {
|
|
errc := make(chan error)
|
|
stat.Reset()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
defer cancel()
|
|
|
|
for i := 0; i < testCase.numberOfHandshakes; i++ {
|
|
stream := &testRPCStream{
|
|
t: t,
|
|
isClient: false,
|
|
}
|
|
// Preload the inbound frames.
|
|
f1 := testutil.MakeFrame("ClientInit")
|
|
f2 := testutil.MakeFrame("ClientFinished")
|
|
in := bytes.NewBuffer(f1)
|
|
in.Write(f2)
|
|
out := new(bytes.Buffer)
|
|
tc := testutil.NewTestConn(in, out)
|
|
shs := &altsHandshaker{
|
|
stream: stream,
|
|
conn: tc,
|
|
serverOpts: DefaultServerHandshakerOptions(),
|
|
side: core.ServerSide,
|
|
}
|
|
go func() {
|
|
_, context, err := shs.ServerHandshake(ctx)
|
|
if err == nil && context == nil {
|
|
errc <- errors.New("expected non-nil ALTS context")
|
|
return
|
|
}
|
|
errc <- err
|
|
shs.Close()
|
|
}()
|
|
}
|
|
|
|
// Ensure all errors are expected.
|
|
for i := 0; i < testCase.numberOfHandshakes; i++ {
|
|
if err := <-errc; err != nil && err != errDropped {
|
|
t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
|
|
}
|
|
}
|
|
|
|
// Ensure that there are no concurrent calls more than the limit.
|
|
if stat.MaxConcurrentCalls > maxPendingHandshakes {
|
|
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
|
|
}
|
|
}
|
|
}
|
|
|
|
// testUnresponsiveRPCStream is used for testing the PeerNotResponding case.
|
|
type testUnresponsiveRPCStream struct {
|
|
grpc.ClientStream
|
|
}
|
|
|
|
func (t *testUnresponsiveRPCStream) Recv() (*altspb.HandshakerResp, error) {
|
|
return &altspb.HandshakerResp{}, nil
|
|
}
|
|
|
|
func (t *testUnresponsiveRPCStream) Send(req *altspb.HandshakerReq) error {
|
|
return nil
|
|
}
|
|
|
|
func (t *testUnresponsiveRPCStream) CloseSend() error {
|
|
return nil
|
|
}
|
|
|
|
func (s) TestPeerNotResponding(t *testing.T) {
|
|
stream := &testUnresponsiveRPCStream{}
|
|
chs := &altsHandshaker{
|
|
stream: stream,
|
|
conn: testutil.NewUnresponsiveTestConn(),
|
|
clientOpts: &ClientHandshakerOptions{
|
|
TargetServiceAccounts: testTargetServiceAccounts,
|
|
ClientIdentity: testClientIdentity,
|
|
},
|
|
side: core.ClientSide,
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
defer cancel()
|
|
_, context, err := chs.ClientHandshake(ctx)
|
|
chs.Close()
|
|
if context != nil {
|
|
t.Error("expected non-nil ALTS context")
|
|
}
|
|
if got, want := err, core.PeerNotRespondingError; got != want {
|
|
t.Errorf("ClientHandshake() = %v, want %v", got, want)
|
|
}
|
|
}
|