From 98c0b8aa3e6ae828ecaa365050874470cc5d15cc Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Wed, 25 May 2016 15:55:03 -0700 Subject: [PATCH] Fix a race conditon and add some small touchups --- balancer.go | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/balancer.go b/balancer.go index da6cde1db..80220cf66 100644 --- a/balancer.go +++ b/balancer.go @@ -129,15 +129,15 @@ func downErrorf(timeout, temporary bool, format string, a ...interface{}) downEr } } -// RoundRobin returns a Balancer that selects addresses round-robin. It starts to watch -// the name resolution updates. +// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch +// the name resolution updates and updates the addresses available correspondingly. func RoundRobin(r naming.Resolver) Balancer { return &roundRobin{r: r} } type roundRobin struct { r naming.Resolver - w []naming.Watcher + w naming.Watcher open []Address // all the addresses the client should potentially connect mu sync.Mutex addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to. @@ -147,8 +147,8 @@ type roundRobin struct { done bool // The Balancer is closed. } -func (rr *roundRobin) watchAddrUpdates(w naming.Watcher) error { - updates, err := w.Next() +func (rr *roundRobin) watchAddrUpdates() error { + updates, err := rr.w.Next() if err != nil { grpclog.Println("grpc: the naming watcher stops working due to %v.", err) return err @@ -187,7 +187,7 @@ func (rr *roundRobin) watchAddrUpdates(w naming.Watcher) error { } } // Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals gets notified. - open := make([]Address, len(rr.open)) + open := make([]Address, len(rr.open), len(rr.open)) copy(open, rr.open) if rr.done { return ErrClientConnClosing @@ -206,11 +206,11 @@ func (rr *roundRobin) Start(target string) error { if err != nil { return err } - rr.w = []naming.Watcher{w} + rr.w = w rr.addrCh = make(chan []Address) go func() { for { - if err := rr.watchAddrUpdates(w); err != nil { + if err := rr.watchAddrUpdates(); err != nil { return } } @@ -218,7 +218,7 @@ func (rr *roundRobin) Start(target string) error { return nil } -// Up appends addr to the end of rr.addrs and sends notification if there +// Up appends addr to the end of rr.connected and sends notification if there // are pending Get() calls. func (rr *roundRobin) Up(addr Address) func(error) { rr.mu.Lock() @@ -241,7 +241,7 @@ func (rr *roundRobin) Up(addr Address) func(error) { } } -// down removes addr from rr.addrs and moves the remaining addrs forward. +// down removes addr from rr.connected and moves the remaining addrs forward. func (rr *roundRobin) down(addr Address, err error) { rr.mu.Lock() defer rr.mu.Unlock() @@ -295,6 +295,12 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad } if len(rr.connected) == 0 { // The newly added addr got removed by Down() again. + if rr.waitCh == nil { + ch = make(chan struct{}) + rr.waitCh = ch + } else { + ch = rr.waitCh + } rr.mu.Unlock() continue } @@ -317,8 +323,8 @@ func (rr *roundRobin) Close() error { rr.mu.Lock() defer rr.mu.Unlock() rr.done = true - for _, w := range rr.w { - w.Close() + if rr.w != nil { + rr.w.Close() } if rr.waitCh != nil { close(rr.waitCh)