mirror of https://github.com/grpc/grpc-go.git
internal/buffer: add Close method to the Unbounded buffer type (#6161)
This commit is contained in:
parent
ebeda756bc
commit
7dfd71831d
|
@ -188,7 +188,10 @@ func (b *rlsBalancer) run() {
|
|||
|
||||
for {
|
||||
select {
|
||||
case u := <-b.updateCh.Get():
|
||||
case u, ok := <-b.updateCh.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.updateCh.Load()
|
||||
switch update := u.(type) {
|
||||
case childPolicyIDAndState:
|
||||
|
@ -450,6 +453,8 @@ func (b *rlsBalancer) Close() {
|
|||
b.dataCache.stop()
|
||||
b.cacheMu.Unlock()
|
||||
|
||||
b.updateCh.Close()
|
||||
|
||||
<-b.done.Done()
|
||||
}
|
||||
|
||||
|
|
|
@ -113,7 +113,10 @@ type subConnUpdate struct {
|
|||
func (ccb *ccBalancerWrapper) watcher() {
|
||||
for {
|
||||
select {
|
||||
case u := <-ccb.updateCh.Get():
|
||||
case u, ok := <-ccb.updateCh.Get():
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
ccb.updateCh.Load()
|
||||
if ccb.closed.HasFired() {
|
||||
break
|
||||
|
@ -155,8 +158,13 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat
|
|||
ccb.updateCh.Put(&ccStateUpdate{ccs: ccs})
|
||||
|
||||
var res interface{}
|
||||
var ok bool
|
||||
select {
|
||||
case res = <-ccb.resultCh.Get():
|
||||
case res, ok = <-ccb.resultCh.Get():
|
||||
if !ok {
|
||||
// The result channel is closed only when the balancer wrapper is closed.
|
||||
return nil
|
||||
}
|
||||
ccb.resultCh.Load()
|
||||
case <-ccb.closed.Done():
|
||||
// Return early if the balancer wrapper is closed while we are waiting for
|
||||
|
@ -296,6 +304,8 @@ func (ccb *ccBalancerWrapper) close() {
|
|||
|
||||
func (ccb *ccBalancerWrapper) handleClose() {
|
||||
ccb.balancer.Close()
|
||||
ccb.updateCh.Close()
|
||||
ccb.resultCh.Close()
|
||||
ccb.done.Fire()
|
||||
}
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ import "sync"
|
|||
// internal/transport/transport.go for an example of this.
|
||||
type Unbounded struct {
|
||||
c chan interface{}
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
backlog []interface{}
|
||||
}
|
||||
|
@ -47,16 +48,18 @@ func NewUnbounded() *Unbounded {
|
|||
// Put adds t to the unbounded buffer.
|
||||
func (b *Unbounded) Put(t interface{}) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
if len(b.backlog) == 0 {
|
||||
select {
|
||||
case b.c <- t:
|
||||
b.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
b.backlog = append(b.backlog, t)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
// Load sends the earliest buffered data, if any, onto the read channel
|
||||
|
@ -64,6 +67,10 @@ func (b *Unbounded) Put(t interface{}) {
|
|||
// value from the read channel.
|
||||
func (b *Unbounded) Load() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
if len(b.backlog) > 0 {
|
||||
select {
|
||||
case b.c <- b.backlog[0]:
|
||||
|
@ -72,7 +79,6 @@ func (b *Unbounded) Load() {
|
|||
default:
|
||||
}
|
||||
}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
// Get returns a read channel on which values added to the buffer, via Put(),
|
||||
|
@ -80,6 +86,20 @@ func (b *Unbounded) Load() {
|
|||
//
|
||||
// Upon reading a value from this channel, users are expected to call Load() to
|
||||
// send the next buffered value onto the channel if there is any.
|
||||
//
|
||||
// If the unbounded buffer is closed, the read channel returned by this method
|
||||
// is closed.
|
||||
func (b *Unbounded) Get() <-chan interface{} {
|
||||
return b.c
|
||||
}
|
||||
|
||||
// Close closes the unbounded buffer.
|
||||
func (b *Unbounded) Close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
b.closed = true
|
||||
close(b.c)
|
||||
}
|
||||
|
|
|
@ -119,3 +119,19 @@ func (s) TestMultipleWriters(t *testing.T) {
|
|||
t.Errorf("reads: %#v, wantReads: %#v", reads, wantReads)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClose closes the buffer and makes sure that nothing is sent after the
|
||||
// buffer is closed.
|
||||
func (s) TestClose(t *testing.T) {
|
||||
ub := NewUnbounded()
|
||||
ub.Close()
|
||||
if v, ok := <-ub.Get(); ok {
|
||||
t.Errorf("Unbounded.Get() = %v, want closed channel", v)
|
||||
}
|
||||
ub.Put(1)
|
||||
ub.Load()
|
||||
if v, ok := <-ub.Get(); ok {
|
||||
t.Errorf("Unbounded.Get() = %v, want closed channel", v)
|
||||
}
|
||||
ub.Close()
|
||||
}
|
||||
|
|
|
@ -56,8 +56,12 @@ func (t *CallbackSerializer) run(ctx context.Context) {
|
|||
for ctx.Err() == nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.callbacks.Close()
|
||||
return
|
||||
case callback := <-t.callbacks.Get():
|
||||
case callback, ok := <-t.callbacks.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
t.callbacks.Load()
|
||||
callback.(func(ctx context.Context))(ctx)
|
||||
}
|
||||
|
|
|
@ -1337,7 +1337,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
|||
|
||||
// setGoAwayReason sets the value of t.goAwayReason based
|
||||
// on the GoAway frame received.
|
||||
// It expects a lock on transport's mutext to be held by
|
||||
// It expects a lock on transport's mutex to be held by
|
||||
// the caller.
|
||||
func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) {
|
||||
t.goAwayReason = GoAwayNoReason
|
||||
|
|
|
@ -426,7 +426,10 @@ func (b *cdsBalancer) handleWatchUpdate(update clusterHandlerUpdate) {
|
|||
func (b *cdsBalancer) run() {
|
||||
for {
|
||||
select {
|
||||
case u := <-b.updateCh.Get():
|
||||
case u, ok := <-b.updateCh.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.updateCh.Load()
|
||||
switch update := u.(type) {
|
||||
case *ccUpdate:
|
||||
|
@ -466,6 +469,7 @@ func (b *cdsBalancer) run() {
|
|||
if b.cachedIdentity != nil {
|
||||
b.cachedIdentity.Close()
|
||||
}
|
||||
b.updateCh.Close()
|
||||
b.logger.Infof("Shutdown")
|
||||
b.done.Fire()
|
||||
return
|
||||
|
|
|
@ -333,6 +333,7 @@ func (b *clusterImplBalancer) Close() {
|
|||
b.childLB = nil
|
||||
b.childState = balancer.State{}
|
||||
}
|
||||
b.pickerUpdateCh.Close()
|
||||
<-b.done.Done()
|
||||
b.logger.Infof("Shutdown")
|
||||
}
|
||||
|
@ -506,7 +507,10 @@ func (b *clusterImplBalancer) run() {
|
|||
defer b.done.Fire()
|
||||
for {
|
||||
select {
|
||||
case update := <-b.pickerUpdateCh.Get():
|
||||
case update, ok := <-b.pickerUpdateCh.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.pickerUpdateCh.Load()
|
||||
b.mu.Lock()
|
||||
if b.closed.HasFired() {
|
||||
|
|
|
@ -265,7 +265,10 @@ func (b *clusterResolverBalancer) handleErrorFromUpdate(err error, fromParent bo
|
|||
func (b *clusterResolverBalancer) run() {
|
||||
for {
|
||||
select {
|
||||
case u := <-b.updateCh.Get():
|
||||
case u, ok := <-b.updateCh.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.updateCh.Load()
|
||||
switch update := u.(type) {
|
||||
case *ccUpdate:
|
||||
|
@ -303,6 +306,7 @@ func (b *clusterResolverBalancer) run() {
|
|||
b.child.Close()
|
||||
b.child = nil
|
||||
}
|
||||
b.updateCh.Close()
|
||||
// This is the *ONLY* point of return from this function.
|
||||
b.logger.Infof("Shutdown")
|
||||
b.done.Fire()
|
||||
|
|
|
@ -362,6 +362,9 @@ func (b *outlierDetectionBalancer) Close() {
|
|||
b.child.Close()
|
||||
b.childMu.Unlock()
|
||||
|
||||
b.scUpdateCh.Close()
|
||||
b.pickerUpdateCh.Close()
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.intervalTimer != nil {
|
||||
|
@ -681,7 +684,10 @@ func (b *outlierDetectionBalancer) run() {
|
|||
defer b.done.Fire()
|
||||
for {
|
||||
select {
|
||||
case update := <-b.scUpdateCh.Get():
|
||||
case update, ok := <-b.scUpdateCh.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.scUpdateCh.Load()
|
||||
if b.closed.HasFired() { // don't send SubConn updates to child after the balancer has been closed
|
||||
return
|
||||
|
@ -692,7 +698,10 @@ func (b *outlierDetectionBalancer) run() {
|
|||
case *ejectionUpdate:
|
||||
b.handleEjectedUpdate(u)
|
||||
}
|
||||
case update := <-b.pickerUpdateCh.Get():
|
||||
case update, ok := <-b.pickerUpdateCh.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.pickerUpdateCh.Load()
|
||||
if b.closed.HasFired() { // don't send picker updates to grpc after the balancer has been closed
|
||||
return
|
||||
|
|
|
@ -205,6 +205,7 @@ func (b *priorityBalancer) UpdateSubConnState(sc balancer.SubConn, state balance
|
|||
|
||||
func (b *priorityBalancer) Close() {
|
||||
b.bg.Close()
|
||||
b.childBalancerStateUpdate.Close()
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
@ -247,7 +248,10 @@ type resumePickerUpdates struct {
|
|||
func (b *priorityBalancer) run() {
|
||||
for {
|
||||
select {
|
||||
case u := <-b.childBalancerStateUpdate.Get():
|
||||
case u, ok := <-b.childBalancerStateUpdate.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.childBalancerStateUpdate.Load()
|
||||
// Needs to handle state update in a goroutine, because each state
|
||||
// update needs to start/close child policy, could result in
|
||||
|
|
|
@ -400,7 +400,11 @@ func (t *Transport) send(ctx context.Context) {
|
|||
continue
|
||||
}
|
||||
sendNodeProto = false
|
||||
case u := <-t.adsRequestCh.Get():
|
||||
case u, ok := <-t.adsRequestCh.Get():
|
||||
if !ok {
|
||||
// No requests will be sent after the adsRequestCh buffer is closed.
|
||||
return
|
||||
}
|
||||
t.adsRequestCh.Load()
|
||||
|
||||
var (
|
||||
|
@ -621,6 +625,7 @@ func (t *Transport) processAckRequest(ack *ackRequest, stream grpc.ClientStream)
|
|||
func (t *Transport) Close() {
|
||||
t.adsRunnerCancel()
|
||||
<-t.adsRunnerDoneCh
|
||||
t.adsRequestCh.Close()
|
||||
t.cc.Close()
|
||||
}
|
||||
|
||||
|
|
|
@ -272,6 +272,7 @@ func (s *GRPCServer) Serve(lis net.Listener) error {
|
|||
// need to explicitly close the listener. Cancellation of the xDS watch
|
||||
// is handled by the listenerWrapper.
|
||||
lw.Close()
|
||||
modeUpdateCh.Close()
|
||||
return nil
|
||||
case <-goodUpdateCh:
|
||||
}
|
||||
|
@ -295,7 +296,10 @@ func (s *GRPCServer) handleServingModeChanges(updateCh *buffer.Unbounded) {
|
|||
select {
|
||||
case <-s.quit.Done():
|
||||
return
|
||||
case u := <-updateCh.Get():
|
||||
case u, ok := <-updateCh.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
updateCh.Load()
|
||||
args := u.(*modeChangeArgs)
|
||||
if args.mode == connectivity.ServingModeNotServing {
|
||||
|
|
Loading…
Reference in New Issue