address the comments

This commit is contained in:
iamqizhao 2016-09-21 17:54:48 -07:00
parent 3f1de24002
commit e77c5bbb41
2 changed files with 26 additions and 17 deletions

View File

@ -70,6 +70,7 @@ type addrInfo struct {
type balancer struct { type balancer struct {
r naming.Resolver r naming.Resolver
mu sync.Mutex mu sync.Mutex
seq int // a sequence number to make sure addrCh does not get stale addresses.
w naming.Watcher w naming.Watcher
addrCh chan []grpc.Address addrCh chan []grpc.Address
rbs []remoteBalancerInfo rbs []remoteBalancerInfo
@ -84,12 +85,12 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
if err != nil { if err != nil {
return err return err
} }
var bAddr remoteBalancerInfo
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
if b.done { if b.done {
return grpc.ErrClientConnClosing return grpc.ErrClientConnClosing
} }
var bAddr remoteBalancerInfo
if len(b.rbs) > 0 { if len(b.rbs) > 0 {
bAddr = b.rbs[0] bAddr = b.rbs[0]
} }
@ -102,7 +103,7 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
case naming.Add: case naming.Add:
var exist bool var exist bool
for _, v := range b.rbs { for _, v := range b.rbs {
// TODO: Is the same addr with different different server name a different balancer? // TODO: Is the same addr with different server name a different balancer?
if addr == v.addr { if addr == v.addr {
exist = true exist = true
break break
@ -139,7 +140,7 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
return nil return nil
} }
func (b *balancer) processServerList(l *lbpb.ServerList) { func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
servers := l.GetServers() servers := l.GetServers()
var ( var (
sl []addrInfo sl []addrInfo
@ -159,7 +160,7 @@ func (b *balancer) processServerList(l *lbpb.ServerList) {
} }
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
if b.done { if b.done || seq < b.seq {
return return
} }
if len(sl) > 0 { if len(sl) > 0 {
@ -172,12 +173,6 @@ func (b *balancer) processServerList(l *lbpb.ServerList) {
} }
func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) { func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) {
b.mu.Lock()
if b.done {
b.mu.Unlock()
return
}
b.mu.Unlock()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false)) stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
@ -185,6 +180,14 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool)
grpclog.Printf("Failed to perform RPC to the remote balancer %v", err) grpclog.Printf("Failed to perform RPC to the remote balancer %v", err)
return return
} }
b.mu.Lock()
if b.done {
b.mu.Unlock()
return
}
b.seq++
seq := b.seq
b.mu.Unlock()
initReq := &lbpb.LoadBalanceRequest{ initReq := &lbpb.LoadBalanceRequest{
LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{ LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
InitialRequest: new(lbpb.InitialLoadBalanceRequest), InitialRequest: new(lbpb.InitialLoadBalanceRequest),
@ -217,7 +220,7 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool)
break break
} }
if serverList := reply.GetServerList(); serverList != nil { if serverList := reply.GetServerList(); serverList != nil {
b.processServerList(serverList) b.processServerList(serverList, seq)
} }
} }
return true return true
@ -307,6 +310,9 @@ func (b *balancer) down(addr grpc.Address, err error) {
func (b *balancer) Up(addr grpc.Address) func(error) { func (b *balancer) Up(addr grpc.Address) func(error) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
if b.done {
return nil
}
var cnt int var cnt int
for _, a := range b.addrs { for _, a := range b.addrs {
if a.addr == addr { if a.addr == addr {

View File

@ -1,6 +1,6 @@
/* /*
* *
* Copyright 2014, Google Inc. * Copyright 2016, Google Inc.
* All rights reserved. * All rights reserved.
* *
* Redistribution and use in source and binary forms, with or without * Redistribution and use in source and binary forms, with or without
@ -57,12 +57,15 @@ type testWatcher struct {
} }
func (w *testWatcher) Next() (updates []*naming.Update, err error) { func (w *testWatcher) Next() (updates []*naming.Update, err error) {
n := <-w.side n, ok := <-w.side
if n == 0 { if !ok {
return nil, fmt.Errorf("w.side is closed") return nil, fmt.Errorf("w.side is closed")
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
u := <-w.update u, ok := <-w.update
if !ok {
break
}
if u != nil { if u != nil {
updates = append(updates, u) updates = append(updates, u)
} }
@ -158,11 +161,11 @@ func stopBackends(servers []*grpc.Server) {
} }
} }
func TestGrpcLB(t *testing.T) { func TestGRPCLB(t *testing.T) {
// Start a backend. // Start a backend.
beLis, err := net.Listen("tcp", "localhost:0") beLis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
t.Fatalf("fadjf") t.Fatalf("Failed to listen %v", err)
} }
backends := startBackends(beLis) backends := startBackends(beLis)
defer stopBackends(backends) defer stopBackends(backends)