mirror of https://github.com/kubernetes/kops.git
				
				
				
			
		
			
				
	
	
		
			915 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			915 lines
		
	
	
		
			27 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 *
 * Copyright 2016, Google Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *     * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above
 * copyright notice, this list of conditions and the following disclaimer
 * in the documentation and/or other materials provided with the
 * distribution.
 *     * Neither the name of Google Inc. nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 */
| 
 | |
| package grpclb
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/golang/protobuf/proto"
 | |
| 	"golang.org/x/net/context"
 | |
| 	"google.golang.org/grpc"
 | |
| 	"google.golang.org/grpc/codes"
 | |
| 	"google.golang.org/grpc/credentials"
 | |
| 	lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
 | |
| 	"google.golang.org/grpc/metadata"
 | |
| 	"google.golang.org/grpc/naming"
 | |
| 	testpb "google.golang.org/grpc/test/grpc_testing"
 | |
| )
 | |
| 
 | |
// Names and tokens shared by the fake resolver, balancer and backends below.
var (
	// lbsn is the server name the remote load balancer announces during the
	// fake handshake, and the ServerName attached to resolved balancer addresses.
	lbsn = "bar.com"
	// besn is the server name the backends announce; it is also the dial
	// target and the service name the fake balancer expects.
	besn = "foo.com"
	// lbToken is the load-balance token handed out with each backend and
	// required by the backends in the "lb-token" metadata.
	lbToken = "iamatoken"
)

| type testWatcher struct {
 | |
| 	// the channel to receives name resolution updates
 | |
| 	update chan *naming.Update
 | |
| 	// the side channel to get to know how many updates in a batch
 | |
| 	side chan int
 | |
| 	// the channel to notifiy update injector that the update reading is done
 | |
| 	readDone chan int
 | |
| }
 | |
| 
 | |
| func (w *testWatcher) Next() (updates []*naming.Update, err error) {
 | |
| 	n, ok := <-w.side
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("w.side is closed")
 | |
| 	}
 | |
| 	for i := 0; i < n; i++ {
 | |
| 		u, ok := <-w.update
 | |
| 		if !ok {
 | |
| 			break
 | |
| 		}
 | |
| 		if u != nil {
 | |
| 			updates = append(updates, u)
 | |
| 		}
 | |
| 	}
 | |
| 	w.readDone <- 0
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (w *testWatcher) Close() {
 | |
| }
 | |
| 
 | |
| // Inject naming resolution updates to the testWatcher.
 | |
| func (w *testWatcher) inject(updates []*naming.Update) {
 | |
| 	w.side <- len(updates)
 | |
| 	for _, u := range updates {
 | |
| 		w.update <- u
 | |
| 	}
 | |
| 	<-w.readDone
 | |
| }
 | |
| 
 | |
| type testNameResolver struct {
 | |
| 	w     *testWatcher
 | |
| 	addrs []string
 | |
| }
 | |
| 
 | |
| func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
 | |
| 	r.w = &testWatcher{
 | |
| 		update:   make(chan *naming.Update, len(r.addrs)),
 | |
| 		side:     make(chan int, 1),
 | |
| 		readDone: make(chan int),
 | |
| 	}
 | |
| 	r.w.side <- len(r.addrs)
 | |
| 	for _, addr := range r.addrs {
 | |
| 		r.w.update <- &naming.Update{
 | |
| 			Op:   naming.Add,
 | |
| 			Addr: addr,
 | |
| 			Metadata: &grpc.AddrMetadataGRPCLB{
 | |
| 				AddrType:   grpc.GRPCLB,
 | |
| 				ServerName: lbsn,
 | |
| 			},
 | |
| 		}
 | |
| 	}
 | |
| 	go func() {
 | |
| 		<-r.w.readDone
 | |
| 	}()
 | |
| 	return r.w, nil
 | |
| }
 | |
| 
 | |
| func (r *testNameResolver) inject(updates []*naming.Update) {
 | |
| 	if r.w != nil {
 | |
| 		r.w.inject(updates)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type serverNameCheckCreds struct {
 | |
| 	expected string
 | |
| 	sn       string
 | |
| }
 | |
| 
 | |
| func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 | |
| 	if _, err := io.WriteString(rawConn, c.sn); err != nil {
 | |
| 		fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err)
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 	return rawConn, nil, nil
 | |
| }
 | |
| func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 | |
| 	b := make([]byte, len(c.expected))
 | |
| 	if _, err := rawConn.Read(b); err != nil {
 | |
| 		fmt.Printf("Failed to read the server name from the server %v", err)
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 	if c.expected != string(b) {
 | |
| 		fmt.Printf("Read the server name %s want %s", string(b), c.expected)
 | |
| 		return nil, nil, errors.New("received unexpected server name")
 | |
| 	}
 | |
| 	return rawConn, nil, nil
 | |
| }
 | |
| func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
 | |
| 	return credentials.ProtocolInfo{}
 | |
| }
 | |
| func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
 | |
| 	return &serverNameCheckCreds{
 | |
| 		expected: c.expected,
 | |
| 	}
 | |
| }
 | |
| func (c *serverNameCheckCreds) OverrideServerName(s string) error {
 | |
| 	c.expected = s
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| type remoteBalancer struct {
 | |
| 	sls       []*lbpb.ServerList
 | |
| 	intervals []time.Duration
 | |
| 	statsDura time.Duration
 | |
| 	done      chan struct{}
 | |
| 	mu        sync.Mutex
 | |
| 	stats     lbpb.ClientStats
 | |
| }
 | |
| 
 | |
| func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer {
 | |
| 	return &remoteBalancer{
 | |
| 		sls:       sls,
 | |
| 		intervals: intervals,
 | |
| 		done:      make(chan struct{}),
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (b *remoteBalancer) stop() {
 | |
| 	close(b.done)
 | |
| }
 | |
| 
 | |
| func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) error {
 | |
| 	req, err := stream.Recv()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	initReq := req.GetInitialRequest()
 | |
| 	if initReq.Name != besn {
 | |
| 		return grpc.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
 | |
| 	}
 | |
| 	resp := &lbpb.LoadBalanceResponse{
 | |
| 		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
 | |
| 			InitialResponse: &lbpb.InitialLoadBalanceResponse{
 | |
| 				ClientStatsReportInterval: &lbpb.Duration{
 | |
| 					Seconds: int64(b.statsDura.Seconds()),
 | |
| 					Nanos:   int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
 | |
| 				},
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 	if err := stream.Send(resp); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	go func() {
 | |
| 		for {
 | |
| 			var (
 | |
| 				req *lbpb.LoadBalanceRequest
 | |
| 				err error
 | |
| 			)
 | |
| 			if req, err = stream.Recv(); err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 			b.mu.Lock()
 | |
| 			b.stats.NumCallsStarted += req.GetClientStats().NumCallsStarted
 | |
| 			b.stats.NumCallsFinished += req.GetClientStats().NumCallsFinished
 | |
| 			b.stats.NumCallsFinishedWithDropForRateLimiting += req.GetClientStats().NumCallsFinishedWithDropForRateLimiting
 | |
| 			b.stats.NumCallsFinishedWithDropForLoadBalancing += req.GetClientStats().NumCallsFinishedWithDropForLoadBalancing
 | |
| 			b.stats.NumCallsFinishedWithClientFailedToSend += req.GetClientStats().NumCallsFinishedWithClientFailedToSend
 | |
| 			b.stats.NumCallsFinishedKnownReceived += req.GetClientStats().NumCallsFinishedKnownReceived
 | |
| 			b.mu.Unlock()
 | |
| 		}
 | |
| 	}()
 | |
| 	for k, v := range b.sls {
 | |
| 		time.Sleep(b.intervals[k])
 | |
| 		resp = &lbpb.LoadBalanceResponse{
 | |
| 			LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
 | |
| 				ServerList: v,
 | |
| 			},
 | |
| 		}
 | |
| 		if err := stream.Send(resp); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 	<-b.done
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| type testServer struct {
 | |
| 	testpb.TestServiceServer
 | |
| 
 | |
| 	addr string
 | |
| }
 | |
| 
 | |
| const testmdkey = "testmd"
 | |
| 
 | |
| func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
 | |
| 	md, ok := metadata.FromIncomingContext(ctx)
 | |
| 	if !ok {
 | |
| 		return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
 | |
| 	}
 | |
| 	if md == nil || md["lb-token"][0] != lbToken {
 | |
| 		return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
 | |
| 	}
 | |
| 	grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
 | |
| 	return &testpb.Empty{}, nil
 | |
| }
 | |
| 
 | |
| func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
 | |
| 	for _, l := range lis {
 | |
| 		creds := &serverNameCheckCreds{
 | |
| 			sn: sn,
 | |
| 		}
 | |
| 		s := grpc.NewServer(grpc.Creds(creds))
 | |
| 		testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()})
 | |
| 		servers = append(servers, s)
 | |
| 		go func(s *grpc.Server, l net.Listener) {
 | |
| 			s.Serve(l)
 | |
| 		}(s, l)
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func stopBackends(servers []*grpc.Server) {
 | |
| 	for _, s := range servers {
 | |
| 		s.Stop()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type testServers struct {
 | |
| 	lbAddr  string
 | |
| 	ls      *remoteBalancer
 | |
| 	lb      *grpc.Server
 | |
| 	beIPs   []net.IP
 | |
| 	bePorts []int
 | |
| }
 | |
| 
 | |
| func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) {
 | |
| 	var (
 | |
| 		beListeners []net.Listener
 | |
| 		ls          *remoteBalancer
 | |
| 		lb          *grpc.Server
 | |
| 		beIPs       []net.IP
 | |
| 		bePorts     []int
 | |
| 	)
 | |
| 	for i := 0; i < numberOfBackends; i++ {
 | |
| 		// Start a backend.
 | |
| 		beLis, e := net.Listen("tcp", "localhost:0")
 | |
| 		if e != nil {
 | |
| 			err = fmt.Errorf("Failed to listen %v", err)
 | |
| 			return
 | |
| 		}
 | |
| 		beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
 | |
| 
 | |
| 		beAddr := strings.Split(beLis.Addr().String(), ":")
 | |
| 		bePort, _ := strconv.Atoi(beAddr[1])
 | |
| 		bePorts = append(bePorts, bePort)
 | |
| 
 | |
| 		beListeners = append(beListeners, beLis)
 | |
| 	}
 | |
| 	backends := startBackends(besn, beListeners...)
 | |
| 
 | |
| 	// Start a load balancer.
 | |
| 	lbLis, err := net.Listen("tcp", "localhost:0")
 | |
| 	if err != nil {
 | |
| 		err = fmt.Errorf("Failed to create the listener for the load balancer %v", err)
 | |
| 		return
 | |
| 	}
 | |
| 	lbCreds := &serverNameCheckCreds{
 | |
| 		sn: lbsn,
 | |
| 	}
 | |
| 	lb = grpc.NewServer(grpc.Creds(lbCreds))
 | |
| 	if err != nil {
 | |
| 		err = fmt.Errorf("Failed to generate the port number %v", err)
 | |
| 		return
 | |
| 	}
 | |
| 	ls = newRemoteBalancer(nil, nil)
 | |
| 	registerLoadBalancerServer(lb, ls)
 | |
| 	go func() {
 | |
| 		lb.Serve(lbLis)
 | |
| 	}()
 | |
| 
 | |
| 	tss = &testServers{
 | |
| 		lbAddr:  lbLis.Addr().String(),
 | |
| 		ls:      ls,
 | |
| 		lb:      lb,
 | |
| 		beIPs:   beIPs,
 | |
| 		bePorts: bePorts,
 | |
| 	}
 | |
| 	cleanup = func() {
 | |
| 		defer stopBackends(backends)
 | |
| 		defer func() {
 | |
| 			ls.stop()
 | |
| 			lb.Stop()
 | |
| 		}()
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func TestGRPCLB(t *testing.T) {
 | |
| 	tss, cleanup, err := newLoadBalancer(1)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to create new load balancer: %v", err)
 | |
| 	}
 | |
| 	defer cleanup()
 | |
| 
 | |
| 	be := &lbpb.Server{
 | |
| 		IpAddress:        tss.beIPs[0],
 | |
| 		Port:             int32(tss.bePorts[0]),
 | |
| 		LoadBalanceToken: lbToken,
 | |
| 	}
 | |
| 	var bes []*lbpb.Server
 | |
| 	bes = append(bes, be)
 | |
| 	sl := &lbpb.ServerList{
 | |
| 		Servers: bes,
 | |
| 	}
 | |
| 	tss.ls.sls = []*lbpb.ServerList{sl}
 | |
| 	tss.ls.intervals = []time.Duration{0}
 | |
| 	creds := serverNameCheckCreds{
 | |
| 		expected: besn,
 | |
| 	}
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | |
| 	defer cancel()
 | |
| 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 | |
| 		addrs: []string{tss.lbAddr},
 | |
| 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to dial to the backend %v", err)
 | |
| 	}
 | |
| 	testC := testpb.NewTestServiceClient(cc)
 | |
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 | |
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 	}
 | |
| 	cc.Close()
 | |
| }
 | |
| 
 | |
| func TestDropRequest(t *testing.T) {
 | |
| 	tss, cleanup, err := newLoadBalancer(2)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to create new load balancer: %v", err)
 | |
| 	}
 | |
| 	defer cleanup()
 | |
| 	tss.ls.sls = []*lbpb.ServerList{{
 | |
| 		Servers: []*lbpb.Server{{
 | |
| 			IpAddress:            tss.beIPs[0],
 | |
| 			Port:                 int32(tss.bePorts[0]),
 | |
| 			LoadBalanceToken:     lbToken,
 | |
| 			DropForLoadBalancing: true,
 | |
| 		}, {
 | |
| 			IpAddress:            tss.beIPs[1],
 | |
| 			Port:                 int32(tss.bePorts[1]),
 | |
| 			LoadBalanceToken:     lbToken,
 | |
| 			DropForLoadBalancing: false,
 | |
| 		}},
 | |
| 	}}
 | |
| 	tss.ls.intervals = []time.Duration{0}
 | |
| 	creds := serverNameCheckCreds{
 | |
| 		expected: besn,
 | |
| 	}
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | |
| 	defer cancel()
 | |
| 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 | |
| 		addrs: []string{tss.lbAddr},
 | |
| 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to dial to the backend %v", err)
 | |
| 	}
 | |
| 	testC := testpb.NewTestServiceClient(cc)
 | |
| 	// The 1st, non-fail-fast RPC should succeed.  This ensures both server
 | |
| 	// connections are made, because the first one has DropForLoadBalancing set to true.
 | |
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
 | |
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 	}
 | |
| 	for i := 0; i < 3; i++ {
 | |
| 		// Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing
 | |
| 		// set to true.
 | |
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
 | |
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
 | |
| 		}
 | |
| 		// Even fail-fast RPCs should succeed since they choose the
 | |
| 		// non-drop-request backend according to the round robin policy.
 | |
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 | |
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 		}
 | |
| 	}
 | |
| 	cc.Close()
 | |
| }
 | |
| 
 | |
| func TestDropRequestFailedNonFailFast(t *testing.T) {
 | |
| 	tss, cleanup, err := newLoadBalancer(1)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to create new load balancer: %v", err)
 | |
| 	}
 | |
| 	defer cleanup()
 | |
| 	be := &lbpb.Server{
 | |
| 		IpAddress:            tss.beIPs[0],
 | |
| 		Port:                 int32(tss.bePorts[0]),
 | |
| 		LoadBalanceToken:     lbToken,
 | |
| 		DropForLoadBalancing: true,
 | |
| 	}
 | |
| 	var bes []*lbpb.Server
 | |
| 	bes = append(bes, be)
 | |
| 	sl := &lbpb.ServerList{
 | |
| 		Servers: bes,
 | |
| 	}
 | |
| 	tss.ls.sls = []*lbpb.ServerList{sl}
 | |
| 	tss.ls.intervals = []time.Duration{0}
 | |
| 	creds := serverNameCheckCreds{
 | |
| 		expected: besn,
 | |
| 	}
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | |
| 	defer cancel()
 | |
| 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 | |
| 		addrs: []string{tss.lbAddr},
 | |
| 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to dial to the backend %v", err)
 | |
| 	}
 | |
| 	testC := testpb.NewTestServiceClient(cc)
 | |
| 	ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond)
 | |
| 	defer cancel()
 | |
| 	if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
 | |
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded)
 | |
| 	}
 | |
| 	cc.Close()
 | |
| }
 | |
| 
 | |
| func TestServerExpiration(t *testing.T) {
 | |
| 	tss, cleanup, err := newLoadBalancer(1)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to create new load balancer: %v", err)
 | |
| 	}
 | |
| 	defer cleanup()
 | |
| 	be := &lbpb.Server{
 | |
| 		IpAddress:        tss.beIPs[0],
 | |
| 		Port:             int32(tss.bePorts[0]),
 | |
| 		LoadBalanceToken: lbToken,
 | |
| 	}
 | |
| 	var bes []*lbpb.Server
 | |
| 	bes = append(bes, be)
 | |
| 	exp := &lbpb.Duration{
 | |
| 		Seconds: 0,
 | |
| 		Nanos:   100000000, // 100ms
 | |
| 	}
 | |
| 	var sls []*lbpb.ServerList
 | |
| 	sl := &lbpb.ServerList{
 | |
| 		Servers:            bes,
 | |
| 		ExpirationInterval: exp,
 | |
| 	}
 | |
| 	sls = append(sls, sl)
 | |
| 	sl = &lbpb.ServerList{
 | |
| 		Servers: bes,
 | |
| 	}
 | |
| 	sls = append(sls, sl)
 | |
| 	var intervals []time.Duration
 | |
| 	intervals = append(intervals, 0)
 | |
| 	intervals = append(intervals, 500*time.Millisecond)
 | |
| 	tss.ls.sls = sls
 | |
| 	tss.ls.intervals = intervals
 | |
| 	creds := serverNameCheckCreds{
 | |
| 		expected: besn,
 | |
| 	}
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | |
| 	defer cancel()
 | |
| 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 | |
| 		addrs: []string{tss.lbAddr},
 | |
| 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to dial to the backend %v", err)
 | |
| 	}
 | |
| 	testC := testpb.NewTestServiceClient(cc)
 | |
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 | |
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 	}
 | |
| 	// Sleep and wake up when the first server list gets expired.
 | |
| 	time.Sleep(150 * time.Millisecond)
 | |
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
 | |
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
 | |
| 	}
 | |
| 	// A non-failfast rpc should be succeeded after the second server list is received from
 | |
| 	// the remote load balancer.
 | |
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
 | |
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 	}
 | |
| 	cc.Close()
 | |
| }
 | |
| 
 | |
| // When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list.
 | |
| func TestBalancerDisconnects(t *testing.T) {
 | |
| 	var (
 | |
| 		lbAddrs []string
 | |
| 		lbs     []*grpc.Server
 | |
| 	)
 | |
| 	for i := 0; i < 3; i++ {
 | |
| 		tss, cleanup, err := newLoadBalancer(1)
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("failed to create new load balancer: %v", err)
 | |
| 		}
 | |
| 		defer cleanup()
 | |
| 
 | |
| 		be := &lbpb.Server{
 | |
| 			IpAddress:        tss.beIPs[0],
 | |
| 			Port:             int32(tss.bePorts[0]),
 | |
| 			LoadBalanceToken: lbToken,
 | |
| 		}
 | |
| 		var bes []*lbpb.Server
 | |
| 		bes = append(bes, be)
 | |
| 		sl := &lbpb.ServerList{
 | |
| 			Servers: bes,
 | |
| 		}
 | |
| 		tss.ls.sls = []*lbpb.ServerList{sl}
 | |
| 		tss.ls.intervals = []time.Duration{0}
 | |
| 
 | |
| 		lbAddrs = append(lbAddrs, tss.lbAddr)
 | |
| 		lbs = append(lbs, tss.lb)
 | |
| 	}
 | |
| 
 | |
| 	creds := serverNameCheckCreds{
 | |
| 		expected: besn,
 | |
| 	}
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | |
| 	defer cancel()
 | |
| 	resolver := &testNameResolver{
 | |
| 		addrs: lbAddrs[:2],
 | |
| 	}
 | |
| 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(resolver)), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to dial to the backend %v", err)
 | |
| 	}
 | |
| 	testC := testpb.NewTestServiceClient(cc)
 | |
| 	var previousTrailer string
 | |
| 	trailer := metadata.MD{}
 | |
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
 | |
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 	} else {
 | |
| 		previousTrailer = trailer[testmdkey][0]
 | |
| 	}
 | |
| 	// The initial resolver update contains lbs[0] and lbs[1].
 | |
| 	// When lbs[0] is stopped, lbs[1] should be used.
 | |
| 	lbs[0].Stop()
 | |
| 	for {
 | |
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
 | |
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 		} else if trailer[testmdkey][0] != previousTrailer {
 | |
| 			// A new backend server should receive the request.
 | |
| 			// The trailer contains the backend address, so the trailer should be different from the previous one.
 | |
| 			previousTrailer = trailer[testmdkey][0]
 | |
| 			break
 | |
| 		}
 | |
| 		time.Sleep(100 * time.Millisecond)
 | |
| 	}
 | |
| 	// Inject a update to add lbs[2] to resolved addresses.
 | |
| 	resolver.inject([]*naming.Update{
 | |
| 		{Op: naming.Add,
 | |
| 			Addr: lbAddrs[2],
 | |
| 			Metadata: &grpc.AddrMetadataGRPCLB{
 | |
| 				AddrType:   grpc.GRPCLB,
 | |
| 				ServerName: lbsn,
 | |
| 			},
 | |
| 		},
 | |
| 	})
 | |
| 	// Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used.
 | |
| 	lbs[1].Stop()
 | |
| 	for {
 | |
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil {
 | |
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 		} else if trailer[testmdkey][0] != previousTrailer {
 | |
| 			// A new backend server should receive the request.
 | |
| 			// The trailer contains the backend address, so the trailer should be different from the previous one.
 | |
| 			break
 | |
| 		}
 | |
| 		time.Sleep(100 * time.Millisecond)
 | |
| 	}
 | |
| 	cc.Close()
 | |
| }
 | |
| 
 | |
// failPreRPCCred is a per-RPC credential that fails any RPC whose first URI
// contains "failtosend" before the RPC is sent. Other RPCs get no extra
// metadata.
type failPreRPCCred struct{}

// GetRequestMetadata returns an error when the first URI contains
// "failtosend" and (nil, nil) otherwise. A call with no URIs is treated as
// allowed instead of panicking on uri[0].
func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	if len(uri) > 0 && strings.Contains(uri[0], "failtosend") {
		return nil, fmt.Errorf("rpc should fail to send")
	}
	return nil, nil
}

// RequireTransportSecurity reports false: this credential does not require
// a secure transport.
func (failPreRPCCred) RequireTransportSecurity() bool {
	return false
}

| func checkStats(stats *lbpb.ClientStats, expected *lbpb.ClientStats) error {
 | |
| 	if !proto.Equal(stats, expected) {
 | |
| 		return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbpb.ClientStats {
 | |
| 	tss, cleanup, err := newLoadBalancer(3)
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("failed to create new load balancer: %v", err)
 | |
| 	}
 | |
| 	defer cleanup()
 | |
| 	tss.ls.sls = []*lbpb.ServerList{{
 | |
| 		Servers: []*lbpb.Server{{
 | |
| 			IpAddress:            tss.beIPs[2],
 | |
| 			Port:                 int32(tss.bePorts[2]),
 | |
| 			LoadBalanceToken:     lbToken,
 | |
| 			DropForLoadBalancing: dropForLoadBalancing,
 | |
| 			DropForRateLimiting:  dropForRateLimiting,
 | |
| 		}},
 | |
| 	}}
 | |
| 	tss.ls.intervals = []time.Duration{0}
 | |
| 	tss.ls.statsDura = 100 * time.Millisecond
 | |
| 	creds := serverNameCheckCreds{expected: besn}
 | |
| 
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 | |
| 	defer cancel()
 | |
| 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 | |
| 		addrs: []string{tss.lbAddr},
 | |
| 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}))
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed to dial to the backend %v", err)
 | |
| 	}
 | |
| 	defer cc.Close()
 | |
| 
 | |
| 	runRPCs(cc)
 | |
| 	time.Sleep(1 * time.Second)
 | |
| 	tss.ls.mu.Lock()
 | |
| 	stats := tss.ls.stats
 | |
| 	tss.ls.mu.Unlock()
 | |
| 	return stats
 | |
| }
 | |
| 
 | |
| const countRPC = 40
 | |
| 
 | |
| func TestGRPCLBStatsUnarySuccess(t *testing.T) {
 | |
| 	stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
 | |
| 		testC := testpb.NewTestServiceClient(cc)
 | |
| 		// The first non-failfast RPC succeeds, all connections are up.
 | |
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
 | |
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 		}
 | |
| 		for i := 0; i < countRPC-1; i++ {
 | |
| 			testC.EmptyCall(context.Background(), &testpb.Empty{})
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	if err := checkStats(&stats, &lbpb.ClientStats{
 | |
| 		NumCallsStarted:               int64(countRPC),
 | |
| 		NumCallsFinished:              int64(countRPC),
 | |
| 		NumCallsFinishedKnownReceived: int64(countRPC),
 | |
| 	}); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) {
 | |
| 	c := 0
 | |
| 	stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
 | |
| 		testC := testpb.NewTestServiceClient(cc)
 | |
| 		for {
 | |
| 			c++
 | |
| 			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 | |
| 				if strings.Contains(err.Error(), "drops requests") {
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		for i := 0; i < countRPC; i++ {
 | |
| 			testC.EmptyCall(context.Background(), &testpb.Empty{})
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	if err := checkStats(&stats, &lbpb.ClientStats{
 | |
| 		NumCallsStarted:                          int64(countRPC + c),
 | |
| 		NumCallsFinished:                         int64(countRPC + c),
 | |
| 		NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
 | |
| 		NumCallsFinishedWithClientFailedToSend:   int64(c - 1),
 | |
| 	}); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) {
 | |
| 	c := 0
 | |
| 	stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
 | |
| 		testC := testpb.NewTestServiceClient(cc)
 | |
| 		for {
 | |
| 			c++
 | |
| 			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 | |
| 				if strings.Contains(err.Error(), "drops requests") {
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		for i := 0; i < countRPC; i++ {
 | |
| 			testC.EmptyCall(context.Background(), &testpb.Empty{})
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	if err := checkStats(&stats, &lbpb.ClientStats{
 | |
| 		NumCallsStarted:                         int64(countRPC + c),
 | |
| 		NumCallsFinished:                        int64(countRPC + c),
 | |
| 		NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
 | |
| 		NumCallsFinishedWithClientFailedToSend:  int64(c - 1),
 | |
| 	}); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
 | |
| 	stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
 | |
| 		testC := testpb.NewTestServiceClient(cc)
 | |
| 		// The first non-failfast RPC succeeds, all connections are up.
 | |
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
 | |
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 		}
 | |
| 		for i := 0; i < countRPC-1; i++ {
 | |
| 			grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc)
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	if err := checkStats(&stats, &lbpb.ClientStats{
 | |
| 		NumCallsStarted:                        int64(countRPC),
 | |
| 		NumCallsFinished:                       int64(countRPC),
 | |
| 		NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
 | |
| 		NumCallsFinishedKnownReceived:          1,
 | |
| 	}); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
 | |
| 	stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
 | |
| 		testC := testpb.NewTestServiceClient(cc)
 | |
| 		// The first non-failfast RPC succeeds, all connections are up.
 | |
| 		stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 		}
 | |
| 		for {
 | |
| 			if _, err = stream.Recv(); err == io.EOF {
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 		for i := 0; i < countRPC-1; i++ {
 | |
| 			stream, err = testC.FullDuplexCall(context.Background())
 | |
| 			if err == nil {
 | |
| 				// Wait for stream to end if err is nil.
 | |
| 				for {
 | |
| 					if _, err = stream.Recv(); err == io.EOF {
 | |
| 						break
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	if err := checkStats(&stats, &lbpb.ClientStats{
 | |
| 		NumCallsStarted:               int64(countRPC),
 | |
| 		NumCallsFinished:              int64(countRPC),
 | |
| 		NumCallsFinishedKnownReceived: int64(countRPC),
 | |
| 	}); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) {
 | |
| 	c := 0
 | |
| 	stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
 | |
| 		testC := testpb.NewTestServiceClient(cc)
 | |
| 		for {
 | |
| 			c++
 | |
| 			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 | |
| 				if strings.Contains(err.Error(), "drops requests") {
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		for i := 0; i < countRPC; i++ {
 | |
| 			testC.FullDuplexCall(context.Background())
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	if err := checkStats(&stats, &lbpb.ClientStats{
 | |
| 		NumCallsStarted:                          int64(countRPC + c),
 | |
| 		NumCallsFinished:                         int64(countRPC + c),
 | |
| 		NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
 | |
| 		NumCallsFinishedWithClientFailedToSend:   int64(c - 1),
 | |
| 	}); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) {
 | |
| 	c := 0
 | |
| 	stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
 | |
| 		testC := testpb.NewTestServiceClient(cc)
 | |
| 		for {
 | |
| 			c++
 | |
| 			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 | |
| 				if strings.Contains(err.Error(), "drops requests") {
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 		for i := 0; i < countRPC; i++ {
 | |
| 			testC.FullDuplexCall(context.Background())
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	if err := checkStats(&stats, &lbpb.ClientStats{
 | |
| 		NumCallsStarted:                         int64(countRPC + c),
 | |
| 		NumCallsFinished:                        int64(countRPC + c),
 | |
| 		NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
 | |
| 		NumCallsFinishedWithClientFailedToSend:  int64(c - 1),
 | |
| 	}); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
 | |
| 	stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
 | |
| 		testC := testpb.NewTestServiceClient(cc)
 | |
| 		// The first non-failfast RPC succeeds, all connections are up.
 | |
| 		stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
 | |
| 		if err != nil {
 | |
| 			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
 | |
| 		}
 | |
| 		for {
 | |
| 			if _, err = stream.Recv(); err == io.EOF {
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 		for i := 0; i < countRPC-1; i++ {
 | |
| 			grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend")
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	if err := checkStats(&stats, &lbpb.ClientStats{
 | |
| 		NumCallsStarted:                        int64(countRPC),
 | |
| 		NumCallsFinished:                       int64(countRPC),
 | |
| 		NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
 | |
| 		NumCallsFinishedKnownReceived:          1,
 | |
| 	}); err != nil {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 |