From 8a6eb0f6e96a246c7c655b0207b62255042f6fbd Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Thu, 13 Apr 2017 13:41:35 -0700 Subject: [PATCH] grpclb should connect to the second balancer (#1181) grpclb needs to connect the second resolved balancer address when the first balancer disconnects. If grpclb gets 2 resolved addresses: balancer1 and balancer2. When balancer1 disconnects, grpclb should automatically start to use balancer2. --- grpclb/grpclb.go | 128 ++++++++++----- grpclb/grpclb_test.go | 357 ++++++++++++++++++++++++------------------ 2 files changed, 298 insertions(+), 187 deletions(-) diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go index b699a525a..ea065fab6 100644 --- a/grpclb/grpclb.go +++ b/grpclb/grpclb.go @@ -111,7 +111,7 @@ type balancer struct { rand *rand.Rand } -func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error { +func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { updates, err := w.Next() if err != nil { return err @@ -121,10 +121,6 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo if b.done { return grpc.ErrClientConnClosing } - var bAddr remoteBalancerInfo - if len(b.rbs) > 0 { - bAddr = b.rbs[0] - } for _, update := range updates { switch update.Op { case naming.Add: @@ -173,21 +169,11 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo } // TODO: Fall back to the basic round-robin load balancing if the resulting address is // not a load balancer. - if len(b.rbs) > 0 { - // For simplicity, always use the first one now. May revisit this decision later. - if b.rbs[0] != bAddr { - select { - case <-ch: - default: - } - // Pick a random one from the list, instead of always using the first one. - if l := len(b.rbs); l > 1 { - tmpIdx := b.rand.Intn(l - 1) - b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0] - } - ch <- b.rbs[0] - } + select { + case <-ch: + default: } + ch <- b.rbs return nil } @@ -261,7 +247,7 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) { func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient, seq int) (retry bool) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false)) + stream, err := lbc.BalanceLoad(ctx) if err != nil { grpclog.Printf("Failed to perform RPC to the remote balancer %v", err) return @@ -340,32 +326,98 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error { } b.w = w b.mu.Unlock() - balancerAddrCh := make(chan remoteBalancerInfo, 1) + balancerAddrsCh := make(chan []remoteBalancerInfo, 1) // Spawn a goroutine to monitor the name resolution of remote load balancer. go func() { for { - if err := b.watchAddrUpdates(w, balancerAddrCh); err != nil { + if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil { grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err) - close(balancerAddrCh) + close(balancerAddrsCh) return } } }() // Spawn a goroutine to talk to the remote load balancer. go func() { - var cc *grpc.ClientConn - for { - rb, ok := <-balancerAddrCh + var ( + cc *grpc.ClientConn + // ccError is closed when there is an error in the current cc. + // A new rb should be picked from rbs and connected. + ccError chan struct{} + rb *remoteBalancerInfo + rbs []remoteBalancerInfo + rbIdx int + ) + + defer func() { + if ccError != nil { + select { + case <-ccError: + default: + close(ccError) + } + } if cc != nil { cc.Close() } - if !ok { - // b is closing. - return + }() + + for { + var ok bool + select { + case rbs, ok = <-balancerAddrsCh: + if !ok { + return + } + foundIdx := -1 + if rb != nil { + for i, trb := range rbs { + if trb == *rb { + foundIdx = i + break + } + } + } + if foundIdx >= 0 { + if foundIdx >= 1 { + // Move the address in use to the beginning of the list. + b.rbs[0], b.rbs[foundIdx] = b.rbs[foundIdx], b.rbs[0] + rbIdx = 0 + } + continue // If found, don't dial new cc. + } else if len(rbs) > 0 { + // Pick a random one from the list, instead of always using the first one. + if l := len(rbs); l > 1 && rb != nil { + tmpIdx := b.rand.Intn(l - 1) + b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0] + } + rbIdx = 0 + rb = &rbs[0] + } else { + // foundIdx < 0 && len(rbs) <= 0. + rb = nil + } + case <-ccError: + ccError = nil + if rbIdx < len(rbs)-1 { + rbIdx++ + rb = &rbs[rbIdx] + } else { + rb = nil + } + } + + if rb == nil { + continue + } + + if cc != nil { + cc.Close() } // Talk to the remote load balancer to get the server list. var err error creds := config.DialCreds + ccError = make(chan struct{}) if creds == nil { cc, err = grpc.Dial(rb.addr, grpc.WithInsecure()) } else { @@ -379,22 +431,24 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error { } if err != nil { grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err) - return + close(ccError) + continue } b.mu.Lock() b.seq++ // tick when getting a new balancer address seq := b.seq b.next = 0 b.mu.Unlock() - go func(cc *grpc.ClientConn) { + go func(cc *grpc.ClientConn, ccError chan struct{}) { lbc := lbpb.NewLoadBalancerClient(cc) - for { - if retry := b.callRemoteBalancer(lbc, seq); !retry { - cc.Close() - return - } + b.callRemoteBalancer(lbc, seq) + cc.Close() + select { + case <-ccError: + default: + close(ccError) } - }(cc) + }(cc, ccError) } }() return nil diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go index ba7824c2e..f6115b28e 100644 --- a/grpclb/grpclb_test.go +++ b/grpclb/grpclb_test.go @@ -99,24 +99,26 @@ func (w *testWatcher) inject(updates []*naming.Update) { } type testNameResolver struct { - w *testWatcher - addr string + w *testWatcher + addrs []string } func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { r.w = &testWatcher{ - update: make(chan *naming.Update, 1), + update: make(chan *naming.Update, len(r.addrs)), side: make(chan int, 1), readDone: make(chan int), } - r.w.side <- 1 - r.w.update <- &naming.Update{ - Op: naming.Add, - Addr: r.addr, - Metadata: &Metadata{ - AddrType: GRPCLB, - ServerName: lbsn, - }, + r.w.side <- len(r.addrs) + for _, addr := range r.addrs { + r.w.update <- &naming.Update{ + Op: naming.Add, + Addr: addr, + Metadata: &Metadata{ + AddrType: GRPCLB, + ServerName: lbsn, + }, + } } go func() { <-r.w.readDone @@ -124,6 +126,12 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { return r.w, nil } +func (r *testNameResolver) inject(updates []*naming.Update) { + if r.w != nil { + r.w.inject(updates) + } +} + type serverNameCheckCreds struct { expected string sn string @@ -212,6 +220,7 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) } type helloServer struct { + addr string } func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) { @@ -223,17 +232,17 @@ func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwp return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md) } return &hwpb.HelloReply{ - Message: "Hello " + in.Name, + Message: "Hello " + in.Name + " for " + s.addr, }, nil } -func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) { +func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) { for _, l := range lis { creds := &serverNameCheckCreds{ sn: sn, } s := grpc.NewServer(grpc.Creds(creds)) - hwpb.RegisterGreeterServer(s, &helloServer{}) + hwpb.RegisterGreeterServer(s, &helloServer{addr: l.Addr().String()}) servers = append(servers, s) go func(s *grpc.Server, l net.Listener) { s.Serve(l) @@ -248,32 +257,86 @@ func stopBackends(servers []*grpc.Server) { } } -func TestGRPCLB(t *testing.T) { - // Start a backend. - beLis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen %v", err) +type testServers struct { + lbAddr string + ls *remoteBalancer + lb *grpc.Server + beIPs []net.IP + bePorts []int +} + +func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) { + var ( + beListeners []net.Listener + ls *remoteBalancer + lb *grpc.Server + beIPs []net.IP + bePorts []int + ) + for i := 0; i < numberOfBackends; i++ { + // Start a backend. + beLis, e := net.Listen("tcp", "localhost:0") + if e != nil { + err = fmt.Errorf("Failed to listen %v", err) + return + } + beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP) + + beAddr := strings.Split(beLis.Addr().String(), ":") + bePort, _ := strconv.Atoi(beAddr[1]) + bePorts = append(bePorts, bePort) + + beListeners = append(beListeners, beLis) } - beAddr := strings.Split(beLis.Addr().String(), ":") - bePort, err := strconv.Atoi(beAddr[1]) - backends := startBackends(t, besn, beLis) - defer stopBackends(backends) + backends := startBackends(besn, beListeners...) // Start a load balancer. lbLis, err := net.Listen("tcp", "localhost:0") if err != nil { - t.Fatalf("Failed to create the listener for the load balancer %v", err) + err = fmt.Errorf("Failed to create the listener for the load balancer %v", err) + return } lbCreds := &serverNameCheckCreds{ sn: lbsn, } - lb := grpc.NewServer(grpc.Creds(lbCreds)) + lb = grpc.NewServer(grpc.Creds(lbCreds)) if err != nil { - t.Fatalf("Failed to generate the port number %v", err) + err = fmt.Errorf("Failed to generate the port number %v", err) + return } + ls = newRemoteBalancer(nil, nil) + lbpb.RegisterLoadBalancerServer(lb, ls) + go func() { + lb.Serve(lbLis) + }() + + tss = &testServers{ + lbAddr: lbLis.Addr().String(), + ls: ls, + lb: lb, + beIPs: beIPs, + bePorts: bePorts, + } + cleanup = func() { + defer stopBackends(backends) + defer func() { + ls.stop() + lb.Stop() + }() + } + return +} + +func TestGRPCLB(t *testing.T) { + tss, cleanup, err := newLoadBalancer(1) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + be := &lbpb.Server{ - IpAddress: beLis.Addr().(*net.TCPAddr).IP, - Port: int32(bePort), + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, } var bes []*lbpb.Server @@ -281,23 +344,14 @@ func TestGRPCLB(t *testing.T) { sl := &lbpb.ServerList{ Servers: bes, } - sls := []*lbpb.ServerList{sl} - intervals := []time.Duration{0} - ls := newRemoteBalancer(sls, intervals) - lbpb.RegisterLoadBalancerServer(lb, ls) - go func() { - lb.Serve(lbLis) - }() - defer func() { - ls.stop() - lb.Stop() - }() + tss.ls.sls = []*lbpb.ServerList{sl} + tss.ls.intervals = []time.Duration{0} creds := serverNameCheckCreds{ expected: besn, } ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{ - addr: lbLis.Addr().String(), + addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) @@ -310,65 +364,31 @@ func TestGRPCLB(t *testing.T) { } func TestDropRequest(t *testing.T) { - // Start 2 backends. - beLis1, err := net.Listen("tcp", "localhost:0") + tss, cleanup, err := newLoadBalancer(2) if err != nil { - t.Fatalf("Failed to listen %v", err) + t.Fatalf("failed to create new load balancer: %v", err) } - beAddr1 := strings.Split(beLis1.Addr().String(), ":") - bePort1, err := strconv.Atoi(beAddr1[1]) - - beLis2, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen %v", err) - } - beAddr2 := strings.Split(beLis2.Addr().String(), ":") - bePort2, err := strconv.Atoi(beAddr2[1]) - - backends := startBackends(t, besn, beLis1, beLis2) - defer stopBackends(backends) - - // Start a load balancer. - lbLis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to create the listener for the load balancer %v", err) - } - lbCreds := &serverNameCheckCreds{ - sn: lbsn, - } - lb := grpc.NewServer(grpc.Creds(lbCreds)) - if err != nil { - t.Fatalf("Failed to generate the port number %v", err) - } - sls := []*lbpb.ServerList{{ + defer cleanup() + tss.ls.sls = []*lbpb.ServerList{{ Servers: []*lbpb.Server{{ - IpAddress: beLis1.Addr().(*net.TCPAddr).IP, - Port: int32(bePort1), + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, DropRequest: true, }, { - IpAddress: beLis2.Addr().(*net.TCPAddr).IP, - Port: int32(bePort2), + IpAddress: tss.beIPs[1], + Port: int32(tss.bePorts[1]), LoadBalanceToken: lbToken, DropRequest: false, }}, }} - intervals := []time.Duration{0} - ls := newRemoteBalancer(sls, intervals) - lbpb.RegisterLoadBalancerServer(lb, ls) - go func() { - lb.Serve(lbLis) - }() - defer func() { - ls.stop() - lb.Stop() - }() + tss.ls.intervals = []time.Duration{0} creds := serverNameCheckCreds{ expected: besn, } ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{ - addr: lbLis.Addr().String(), + addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) @@ -395,31 +415,14 @@ func TestDropRequest(t *testing.T) { } func TestDropRequestFailedNonFailFast(t *testing.T) { - // Start a backend. - beLis, err := net.Listen("tcp", "localhost:0") + tss, cleanup, err := newLoadBalancer(1) if err != nil { - t.Fatalf("Failed to listen %v", err) - } - beAddr := strings.Split(beLis.Addr().String(), ":") - bePort, err := strconv.Atoi(beAddr[1]) - backends := startBackends(t, besn, beLis) - defer stopBackends(backends) - - // Start a load balancer. - lbLis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to create the listener for the load balancer %v", err) - } - lbCreds := &serverNameCheckCreds{ - sn: lbsn, - } - lb := grpc.NewServer(grpc.Creds(lbCreds)) - if err != nil { - t.Fatalf("Failed to generate the port number %v", err) + t.Fatalf("failed to create new load balancer: %v", err) } + defer cleanup() be := &lbpb.Server{ - IpAddress: beLis.Addr().(*net.TCPAddr).IP, - Port: int32(bePort), + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, DropRequest: true, } @@ -428,23 +431,14 @@ func TestDropRequestFailedNonFailFast(t *testing.T) { sl := &lbpb.ServerList{ Servers: bes, } - sls := []*lbpb.ServerList{sl} - intervals := []time.Duration{0} - ls := newRemoteBalancer(sls, intervals) - lbpb.RegisterLoadBalancerServer(lb, ls) - go func() { - lb.Serve(lbLis) - }() - defer func() { - ls.stop() - lb.Stop() - }() + tss.ls.sls = []*lbpb.ServerList{sl} + tss.ls.intervals = []time.Duration{0} creds := serverNameCheckCreds{ expected: besn, } ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{ - addr: lbLis.Addr().String(), + addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) @@ -458,31 +452,14 @@ func TestDropRequestFailedNonFailFast(t *testing.T) { } func TestServerExpiration(t *testing.T) { - // Start a backend. - beLis, err := net.Listen("tcp", "localhost:0") + tss, cleanup, err := newLoadBalancer(1) if err != nil { - t.Fatalf("Failed to listen %v", err) - } - beAddr := strings.Split(beLis.Addr().String(), ":") - bePort, err := strconv.Atoi(beAddr[1]) - backends := startBackends(t, besn, beLis) - defer stopBackends(backends) - - // Start a load balancer. - lbLis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to create the listener for the load balancer %v", err) - } - lbCreds := &serverNameCheckCreds{ - sn: lbsn, - } - lb := grpc.NewServer(grpc.Creds(lbCreds)) - if err != nil { - t.Fatalf("Failed to generate the port number %v", err) + t.Fatalf("failed to create new load balancer: %v", err) } + defer cleanup() be := &lbpb.Server{ - IpAddress: beLis.Addr().(*net.TCPAddr).IP, - Port: int32(bePort), + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, } var bes []*lbpb.Server @@ -504,21 +481,14 @@ func TestServerExpiration(t *testing.T) { var intervals []time.Duration intervals = append(intervals, 0) intervals = append(intervals, 500*time.Millisecond) - ls := newRemoteBalancer(sls, intervals) - lbpb.RegisterLoadBalancerServer(lb, ls) - go func() { - lb.Serve(lbLis) - }() - defer func() { - ls.stop() - lb.Stop() - }() + tss.ls.sls = sls + tss.ls.intervals = intervals creds := serverNameCheckCreds{ expected: besn, } ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{ - addr: lbLis.Addr().String(), + addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) @@ -539,3 +509,90 @@ func TestServerExpiration(t *testing.T) { } cc.Close() } + +// When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list. +func TestBalancerDisconnects(t *testing.T) { + var ( + lbAddrs []string + lbs []*grpc.Server + ) + for i := 0; i < 3; i++ { + tss, cleanup, err := newLoadBalancer(1) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + + be := &lbpb.Server{ + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + } + var bes []*lbpb.Server + bes = append(bes, be) + sl := &lbpb.ServerList{ + Servers: bes, + } + tss.ls.sls = []*lbpb.ServerList{sl} + tss.ls.intervals = []time.Duration{0} + + lbAddrs = append(lbAddrs, tss.lbAddr) + lbs = append(lbs, tss.lb) + } + + creds := serverNameCheckCreds{ + expected: besn, + } + ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + resolver := &testNameResolver{ + addrs: lbAddrs[:2], + } + cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(resolver)), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + helloC := hwpb.NewGreeterClient(cc) + var message string + if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { + t.Fatalf("%v.SayHello(_, _) = _, %v, want _, ", helloC, err) + } else { + message = resp.Message + } + // The initial resolver update contains lbs[0] and lbs[1]. + // When lbs[0] is stopped, lbs[1] should be used. + lbs[0].Stop() + for { + if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { + t.Fatalf("%v.SayHello(_, _) = _, %v, want _, ", helloC, err) + } else if resp.Message != message { + // A new backend server should receive the request. + // The response contains the backend address, so the message should be different from the previous one. + message = resp.Message + break + } + time.Sleep(100 * time.Millisecond) + } + // Inject a update to add lbs[2] to resolved addresses. + resolver.inject([]*naming.Update{ + {Op: naming.Add, + Addr: lbAddrs[2], + Metadata: &Metadata{ + AddrType: GRPCLB, + ServerName: lbsn, + }, + }, + }) + // Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used. + lbs[1].Stop() + for { + if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { + t.Fatalf("%v.SayHello(_, _) = _, %v, want _, ", helloC, err) + } else if resp.Message != message { + // A new backend server should receive the request. + // The response contains the backend address, so the message should be different from the previous one. + break + } + time.Sleep(100 * time.Millisecond) + } + cc.Close() +}