internal/buffer: add Close method to the Unbounded buffer type (#6161)

Ernest Nguyen Hung authored on 2023-04-19 01:53:59 +02:00; committed by GitHub
commit 7dfd71831d (parent ebeda756bc)
13 changed files with 105 additions and 16 deletions


@@ -188,7 +188,10 @@ func (b *rlsBalancer) run() {
 	for {
 		select {
-		case u := <-b.updateCh.Get():
+		case u, ok := <-b.updateCh.Get():
+			if !ok {
+				return
+			}
 			b.updateCh.Load()
 			switch update := u.(type) {
 			case childPolicyIDAndState:
@@ -450,6 +453,8 @@ func (b *rlsBalancer) Close() {
 	b.dataCache.stop()
 	b.cacheMu.Unlock()
+	b.updateCh.Close()
+	<-b.done.Done()
 }
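
Nearly every file in this commit applies the same two-part pattern these two hunks show together: the run goroutine switches to a comma-ok receive on Get() and returns once the buffer is closed, while the balancer's Close() closes the buffer and then waits on a done event for that goroutine to drain. Below is a minimal self-contained sketch of the pattern; the worker type and its methods are illustrative, not grpc-go APIs, and since internal/buffer and internal/grpcsync are internal packages, code like this can only live inside the grpc-go module.

package main

import (
	"fmt"

	"google.golang.org/grpc/internal/buffer"
	"google.golang.org/grpc/internal/grpcsync"
)

// worker is a hypothetical stand-in for the balancers touched by this
// commit: a single goroutine consumes updates until the buffer is closed.
type worker struct {
	updateCh *buffer.Unbounded
	done     *grpcsync.Event
}

func newWorker() *worker {
	w := &worker{updateCh: buffer.NewUnbounded(), done: grpcsync.NewEvent()}
	go w.run()
	return w
}

func (w *worker) run() {
	defer w.done.Fire()
	for {
		// The comma-ok form distinguishes a delivered update from the
		// channel having been closed by Close().
		u, ok := <-w.updateCh.Get()
		if !ok {
			return
		}
		w.updateCh.Load()
		fmt.Println("handled:", u)
	}
}

// Close mirrors the balancer Close() methods in this commit: close the
// buffer, then wait for the run goroutine to observe that and exit.
func (w *worker) Close() {
	w.updateCh.Close()
	<-w.done.Done()
}

func main() {
	w := newWorker()
	w.updateCh.Put("update-1")
	w.Close() // "update-1" is still delivered before run exits
}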


@@ -113,7 +113,10 @@ type subConnUpdate struct {
 func (ccb *ccBalancerWrapper) watcher() {
 	for {
 		select {
-		case u := <-ccb.updateCh.Get():
+		case u, ok := <-ccb.updateCh.Get():
+			if !ok {
+				break
+			}
 			ccb.updateCh.Load()
 			if ccb.closed.HasFired() {
 				break
@@ -155,8 +158,13 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat
 	ccb.updateCh.Put(&ccStateUpdate{ccs: ccs})
 
 	var res interface{}
+	var ok bool
 	select {
-	case res = <-ccb.resultCh.Get():
+	case res, ok = <-ccb.resultCh.Get():
+		if !ok {
+			// The result channel is closed only when the balancer wrapper is closed.
+			return nil
+		}
 		ccb.resultCh.Load()
 	case <-ccb.closed.Done():
 		// Return early if the balancer wrapper is closed while we are waiting for
@@ -296,6 +304,8 @@ func (ccb *ccBalancerWrapper) close() {
 func (ccb *ccBalancerWrapper) handleClose() {
 	ccb.balancer.Close()
+	ccb.updateCh.Close()
+	ccb.resultCh.Close()
 	ccb.done.Fire()
 }
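
The wrapper owns two buffers: updateCh feeds the watcher goroutine, and resultCh carries the result of each update back to the caller blocked in updateClientConnState. handleClose therefore has to close both, and closing resultCh is what unblocks such a caller and makes it return nil. A rough sketch of this request/response handshake, under the same caveats as the previous example (handshake, worker, and do are illustrative names, not grpc-go APIs):

package main

import (
	"errors"
	"fmt"

	"google.golang.org/grpc/internal/buffer"
)

// handshake is a hypothetical reduction of ccBalancerWrapper's
// updateCh/resultCh pairing: requests flow in on one buffer and the
// result of each request flows back out on the other.
type handshake struct {
	reqCh *buffer.Unbounded
	resCh *buffer.Unbounded
}

func (h *handshake) worker() {
	for {
		req, ok := <-h.reqCh.Get()
		if !ok {
			return
		}
		h.reqCh.Load()
		// Hand the result of this request back to the blocked caller.
		h.resCh.Put(fmt.Sprintf("handled %v", req))
	}
}

// do mirrors updateClientConnState: enqueue a request, then block until
// the worker replies, or until the result buffer is closed on shutdown.
func (h *handshake) do(req string) (string, error) {
	h.reqCh.Put(req)
	res, ok := <-h.resCh.Get()
	if !ok {
		return "", errors.New("closed while waiting for a result")
	}
	h.resCh.Load()
	return res.(string), nil
}

func main() {
	h := &handshake{reqCh: buffer.NewUnbounded(), resCh: buffer.NewUnbounded()}
	go h.worker()
	res, err := h.do("update")
	fmt.Println(res, err) // handled update <nil>
	h.reqCh.Close()       // stops the worker
	h.resCh.Close()       // unblocks any caller still waiting in do
}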


@@ -35,6 +35,7 @@ import "sync"
 // internal/transport/transport.go for an example of this.
 type Unbounded struct {
 	c       chan interface{}
+	closed  bool
 	mu      sync.Mutex
 	backlog []interface{}
 }
@@ -47,16 +48,18 @@ func NewUnbounded() *Unbounded {
 // Put adds t to the unbounded buffer.
 func (b *Unbounded) Put(t interface{}) {
 	b.mu.Lock()
+	defer b.mu.Unlock()
+	if b.closed {
+		return
+	}
 	if len(b.backlog) == 0 {
 		select {
 		case b.c <- t:
-			b.mu.Unlock()
 			return
 		default:
 		}
 	}
 	b.backlog = append(b.backlog, t)
-	b.mu.Unlock()
 }
 
 // Load sends the earliest buffered data, if any, onto the read channel
@@ -64,6 +67,10 @@ func (b *Unbounded) Put(t interface{}) {
 // value from the read channel.
 func (b *Unbounded) Load() {
 	b.mu.Lock()
+	defer b.mu.Unlock()
+	if b.closed {
+		return
+	}
 	if len(b.backlog) > 0 {
 		select {
 		case b.c <- b.backlog[0]:
@@ -72,7 +79,6 @@ func (b *Unbounded) Load() {
 		default:
 		}
 	}
-	b.mu.Unlock()
 }
 
 // Get returns a read channel on which values added to the buffer, via Put(),
@@ -80,6 +86,20 @@ func (b *Unbounded) Load() {
 //
 // Upon reading a value from this channel, users are expected to call Load() to
 // send the next buffered value onto the channel if there is any.
+//
+// If the unbounded buffer is closed, the read channel returned by this method
+// is closed.
 func (b *Unbounded) Get() <-chan interface{} {
 	return b.c
 }
+
+// Close closes the unbounded buffer.
+func (b *Unbounded) Close() {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	if b.closed {
+		return
+	}
+	b.closed = true
+	close(b.c)
+}
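
Taken together, these hunks give the buffer the following semantics: Close is idempotent, Put and Load are no-ops once the buffer is closed, and the channel returned by Get is closed, with any value already sitting in that channel still delivered before receives begin reporting ok == false. A short sketch demonstrating those semantics (again assuming code inside the grpc-go module, since internal/buffer is not importable from outside it):

package main

import (
	"fmt"

	"google.golang.org/grpc/internal/buffer"
)

func main() {
	b := buffer.NewUnbounded()
	b.Put("in-flight") // backlog is empty, so this lands in the channel
	b.Close()          // closes the channel; later Put/Load are no-ops

	// A value handed to the channel before Close is still delivered.
	v, ok := <-b.Get()
	fmt.Println(v, ok) // in-flight true

	b.Put("dropped") // no-op: the buffer is already closed
	v, ok = <-b.Get()
	fmt.Println(v, ok) // <nil> false

	b.Close() // Close is idempotent, so this is safe
}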


@@ -119,3 +119,19 @@ func (s) TestMultipleWriters(t *testing.T) {
 		t.Errorf("reads: %#v, wantReads: %#v", reads, wantReads)
 	}
 }
+
+// TestClose closes the buffer and makes sure that nothing is sent after the
+// buffer is closed.
+func (s) TestClose(t *testing.T) {
+	ub := NewUnbounded()
+	ub.Close()
+	if v, ok := <-ub.Get(); ok {
+		t.Errorf("Unbounded.Get() = %v, want closed channel", v)
+	}
+	ub.Put(1)
+	ub.Load()
+	if v, ok := <-ub.Get(); ok {
+		t.Errorf("Unbounded.Get() = %v, want closed channel", v)
+	}
+	ub.Close()
+}


@@ -56,8 +56,12 @@ func (t *CallbackSerializer) run(ctx context.Context) {
 	for ctx.Err() == nil {
 		select {
 		case <-ctx.Done():
+			t.callbacks.Close()
 			return
-		case callback := <-t.callbacks.Get():
+		case callback, ok := <-t.callbacks.Get():
+			if !ok {
+				return
+			}
 			t.callbacks.Load()
 			callback.(func(ctx context.Context))(ctx)
 		}
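
Both exit paths of the serializer's loop now involve a closed buffer: cancelling ctx closes the buffer before returning, and a buffer closed elsewhere stops the loop through the comma-ok receive. A reduced sketch of the same drain loop (runSerialized is an illustrative stand-in, not the CallbackSerializer API, with the usual caveat that internal/buffer is only importable inside the grpc-go module):

package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/internal/buffer"
)

// runSerialized mirrors CallbackSerializer.run: execute queued callbacks
// one at a time until ctx is cancelled or the buffer is closed.
func runSerialized(ctx context.Context, callbacks *buffer.Unbounded) {
	for ctx.Err() == nil {
		select {
		case <-ctx.Done():
			callbacks.Close()
			return
		case cb, ok := <-callbacks.Get():
			if !ok {
				return
			}
			callbacks.Load()
			cb.(func(context.Context))(ctx)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cbs := buffer.NewUnbounded()
	done := make(chan struct{})
	executed := make(chan struct{})
	go func() {
		defer close(done)
		runSerialized(ctx, cbs)
	}()
	cbs.Put(func(context.Context) {
		fmt.Println("callback ran")
		close(executed)
	})
	<-executed // let the callback run, then shut the loop down
	cancel()
	<-done
}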


@@ -1337,7 +1337,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
 
 // setGoAwayReason sets the value of t.goAwayReason based
 // on the GoAway frame received.
-// It expects a lock on transport's mutext to be held by
+// It expects a lock on transport's mutex to be held by
 // the caller.
 func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) {
 	t.goAwayReason = GoAwayNoReason


@@ -426,7 +426,10 @@ func (b *cdsBalancer) handleWatchUpdate(update clusterHandlerUpdate) {
 func (b *cdsBalancer) run() {
 	for {
 		select {
-		case u := <-b.updateCh.Get():
+		case u, ok := <-b.updateCh.Get():
+			if !ok {
+				return
+			}
 			b.updateCh.Load()
 			switch update := u.(type) {
 			case *ccUpdate:
@@ -466,6 +469,7 @@ func (b *cdsBalancer) run() {
 			if b.cachedIdentity != nil {
 				b.cachedIdentity.Close()
 			}
+			b.updateCh.Close()
 			b.logger.Infof("Shutdown")
 			b.done.Fire()
 			return


@@ -333,6 +333,7 @@ func (b *clusterImplBalancer) Close() {
 		b.childLB = nil
 		b.childState = balancer.State{}
 	}
+	b.pickerUpdateCh.Close()
 	<-b.done.Done()
 	b.logger.Infof("Shutdown")
 }
@@ -506,7 +507,10 @@ func (b *clusterImplBalancer) run() {
 	defer b.done.Fire()
 	for {
 		select {
-		case update := <-b.pickerUpdateCh.Get():
+		case update, ok := <-b.pickerUpdateCh.Get():
+			if !ok {
+				return
+			}
 			b.pickerUpdateCh.Load()
 			b.mu.Lock()
 			if b.closed.HasFired() {


@@ -265,7 +265,10 @@ func (b *clusterResolverBalancer) handleErrorFromUpdate(err error, fromParent bo
 func (b *clusterResolverBalancer) run() {
 	for {
 		select {
-		case u := <-b.updateCh.Get():
+		case u, ok := <-b.updateCh.Get():
+			if !ok {
+				return
+			}
 			b.updateCh.Load()
 			switch update := u.(type) {
 			case *ccUpdate:
@@ -303,6 +306,7 @@ func (b *clusterResolverBalancer) run() {
 				b.child.Close()
 				b.child = nil
 			}
+			b.updateCh.Close()
 			// This is the *ONLY* point of return from this function.
 			b.logger.Infof("Shutdown")
 			b.done.Fire()


@@ -362,6 +362,9 @@ func (b *outlierDetectionBalancer) Close() {
 	b.child.Close()
 	b.childMu.Unlock()
+	b.scUpdateCh.Close()
+	b.pickerUpdateCh.Close()
+
 	b.mu.Lock()
 	defer b.mu.Unlock()
 	if b.intervalTimer != nil {
@@ -681,7 +684,10 @@ func (b *outlierDetectionBalancer) run() {
 	defer b.done.Fire()
 	for {
 		select {
-		case update := <-b.scUpdateCh.Get():
+		case update, ok := <-b.scUpdateCh.Get():
+			if !ok {
+				return
+			}
 			b.scUpdateCh.Load()
 			if b.closed.HasFired() { // don't send SubConn updates to child after the balancer has been closed
 				return
@@ -692,7 +698,10 @@ func (b *outlierDetectionBalancer) run() {
 			case *ejectionUpdate:
 				b.handleEjectedUpdate(u)
 			}
-		case update := <-b.pickerUpdateCh.Get():
+		case update, ok := <-b.pickerUpdateCh.Get():
+			if !ok {
+				return
+			}
 			b.pickerUpdateCh.Load()
 			if b.closed.HasFired() { // don't send picker updates to grpc after the balancer has been closed
 				return


@@ -205,6 +205,7 @@ func (b *priorityBalancer) UpdateSubConnState(sc balancer.SubConn, state balance
 func (b *priorityBalancer) Close() {
 	b.bg.Close()
+	b.childBalancerStateUpdate.Close()
 
 	b.mu.Lock()
 	defer b.mu.Unlock()
@@ -247,7 +248,10 @@ type resumePickerUpdates struct {
 func (b *priorityBalancer) run() {
 	for {
 		select {
-		case u := <-b.childBalancerStateUpdate.Get():
+		case u, ok := <-b.childBalancerStateUpdate.Get():
+			if !ok {
+				return
+			}
 			b.childBalancerStateUpdate.Load()
 			// Needs to handle state update in a goroutine, because each state
 			// update needs to start/close child policy, could result in


@@ -400,7 +400,11 @@ func (t *Transport) send(ctx context.Context) {
 				continue
 			}
 			sendNodeProto = false
-		case u := <-t.adsRequestCh.Get():
+		case u, ok := <-t.adsRequestCh.Get():
+			if !ok {
+				// No requests will be sent after the adsRequestCh buffer is closed.
+				return
+			}
 			t.adsRequestCh.Load()
 			var (
@@ -621,6 +625,7 @@ func (t *Transport) processAckRequest(ack *ackRequest, stream grpc.ClientStream)
 func (t *Transport) Close() {
 	t.adsRunnerCancel()
 	<-t.adsRunnerDoneCh
+	t.adsRequestCh.Close()
 	t.cc.Close()
 }


@@ -272,6 +272,7 @@ func (s *GRPCServer) Serve(lis net.Listener) error {
 			// need to explicitly close the listener. Cancellation of the xDS watch
 			// is handled by the listenerWrapper.
 			lw.Close()
+			modeUpdateCh.Close()
 			return nil
 		case <-goodUpdateCh:
 		}
@@ -295,7 +296,10 @@ func (s *GRPCServer) handleServingModeChanges(updateCh *buffer.Unbounded) {
 		select {
 		case <-s.quit.Done():
 			return
-		case u := <-updateCh.Get():
+		case u, ok := <-updateCh.Get():
+			if !ok {
+				return
+			}
 			updateCh.Load()
 			args := u.(*modeChangeArgs)
 			if args.mode == connectivity.ServingModeNotServing {