mirror of https://github.com/grpc/grpc-go.git
buffer & grpcsync: various cleanups and improvements (#6785)
This commit is contained in:
parent
424db25679
commit
b98104ec5a
|
@ -18,7 +18,10 @@
|
|||
// Package buffer provides an implementation of an unbounded buffer.
|
||||
package buffer
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Unbounded is an implementation of an unbounded buffer which does not use
|
||||
// extra goroutines. This is typically used for passing updates from one entity
|
||||
|
@ -36,6 +39,7 @@ import "sync"
|
|||
type Unbounded struct {
|
||||
c chan any
|
||||
closed bool
|
||||
closing bool
|
||||
mu sync.Mutex
|
||||
backlog []any
|
||||
}
|
||||
|
@ -45,32 +49,32 @@ func NewUnbounded() *Unbounded {
|
|||
return &Unbounded{c: make(chan any, 1)}
|
||||
}
|
||||
|
||||
var errBufferClosed = errors.New("Put called on closed buffer.Unbounded")
|
||||
|
||||
// Put adds t to the unbounded buffer.
|
||||
func (b *Unbounded) Put(t any) {
|
||||
func (b *Unbounded) Put(t any) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return
|
||||
if b.closing {
|
||||
return errBufferClosed
|
||||
}
|
||||
if len(b.backlog) == 0 {
|
||||
select {
|
||||
case b.c <- t:
|
||||
return
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
b.backlog = append(b.backlog, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load sends the earliest buffered data, if any, onto the read channel
|
||||
// returned by Get(). Users are expected to call this every time they read a
|
||||
// Load sends the earliest buffered data, if any, onto the read channel returned
|
||||
// by Get(). Users are expected to call this every time they successfully read a
|
||||
// value from the read channel.
|
||||
func (b *Unbounded) Load() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
if len(b.backlog) > 0 {
|
||||
select {
|
||||
case b.c <- b.backlog[0]:
|
||||
|
@ -78,6 +82,8 @@ func (b *Unbounded) Load() {
|
|||
b.backlog = b.backlog[1:]
|
||||
default:
|
||||
}
|
||||
} else if b.closing && !b.closed {
|
||||
close(b.c)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,18 +94,23 @@ func (b *Unbounded) Load() {
|
|||
// send the next buffered value onto the channel if there is any.
|
||||
//
|
||||
// If the unbounded buffer is closed, the read channel returned by this method
|
||||
// is closed.
|
||||
// is closed after all data is drained.
|
||||
func (b *Unbounded) Get() <-chan any {
|
||||
return b.c
|
||||
}
|
||||
|
||||
// Close closes the unbounded buffer.
|
||||
// Close closes the unbounded buffer. No subsequent data may be Put(), and the
|
||||
// channel returned from Get() will be closed after all the data is read and
|
||||
// Load() is called for the final time.
|
||||
func (b *Unbounded) Close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
if b.closing {
|
||||
return
|
||||
}
|
||||
b.closed = true
|
||||
close(b.c)
|
||||
b.closing = true
|
||||
if len(b.backlog) == 0 {
|
||||
b.closed = true
|
||||
close(b.c)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ func init() {
|
|||
}
|
||||
|
||||
// TestSingleWriter starts one reader and one writer goroutine and makes sure
|
||||
// that the reader gets all the value added to the buffer by the writer.
|
||||
// that the reader gets all the values added to the buffer by the writer.
|
||||
func (s) TestSingleWriter(t *testing.T) {
|
||||
ub := NewUnbounded()
|
||||
reads := []int{}
|
||||
|
@ -124,14 +124,25 @@ func (s) TestMultipleWriters(t *testing.T) {
|
|||
// buffer is closed.
|
||||
func (s) TestClose(t *testing.T) {
|
||||
ub := NewUnbounded()
|
||||
ub.Close()
|
||||
if v, ok := <-ub.Get(); ok {
|
||||
t.Errorf("Unbounded.Get() = %v, want closed channel", v)
|
||||
if err := ub.Put(1); err != nil {
|
||||
t.Fatalf("Unbounded.Put() = %v; want nil", err)
|
||||
}
|
||||
ub.Close()
|
||||
if err := ub.Put(1); err == nil {
|
||||
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
|
||||
}
|
||||
if v, ok := <-ub.Get(); !ok {
|
||||
t.Errorf("Unbounded.Get() = %v, %v, want %v, %v", v, ok, 1, true)
|
||||
}
|
||||
if err := ub.Put(1); err == nil {
|
||||
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
|
||||
}
|
||||
ub.Put(1)
|
||||
ub.Load()
|
||||
if v, ok := <-ub.Get(); ok {
|
||||
t.Errorf("Unbounded.Get() = %v, want closed channel", v)
|
||||
}
|
||||
ub.Close()
|
||||
if err := ub.Put(1); err == nil {
|
||||
t.Fatalf("Unbounded.Put() = <nil>; want non-nil error")
|
||||
}
|
||||
ub.Close() // ignored
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ package grpcsync
|
|||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc/internal/buffer"
|
||||
)
|
||||
|
@ -38,8 +37,6 @@ type CallbackSerializer struct {
|
|||
done chan struct{}
|
||||
|
||||
callbacks *buffer.Unbounded
|
||||
closedMu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewCallbackSerializer returns a new CallbackSerializer instance. The provided
|
||||
|
@ -65,56 +62,34 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
|
|||
// callbacks to be executed by the serializer. It is not possible to add
|
||||
// callbacks once the context passed to NewCallbackSerializer is cancelled.
|
||||
func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool {
|
||||
cs.closedMu.Lock()
|
||||
defer cs.closedMu.Unlock()
|
||||
|
||||
if cs.closed {
|
||||
return false
|
||||
}
|
||||
cs.callbacks.Put(f)
|
||||
return true
|
||||
return cs.callbacks.Put(f) == nil
|
||||
}
|
||||
|
||||
func (cs *CallbackSerializer) run(ctx context.Context) {
|
||||
var backlog []func(context.Context)
|
||||
|
||||
defer close(cs.done)
|
||||
|
||||
// TODO: when Go 1.21 is the oldest supported version, this loop and Close
|
||||
// can be replaced with:
|
||||
//
|
||||
// context.AfterFunc(ctx, cs.callbacks.Close)
|
||||
for ctx.Err() == nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Do nothing here. Next iteration of the for loop will not happen,
|
||||
// since ctx.Err() would be non-nil.
|
||||
case callback, ok := <-cs.callbacks.Get():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case cb := <-cs.callbacks.Get():
|
||||
cs.callbacks.Load()
|
||||
callback.(func(ctx context.Context))(ctx)
|
||||
cb.(func(context.Context))(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch pending callbacks if any, and execute them before returning from
|
||||
// this method and closing cs.done.
|
||||
cs.closedMu.Lock()
|
||||
cs.closed = true
|
||||
backlog = cs.fetchPendingCallbacks()
|
||||
// Close the buffer to prevent new callbacks from being added.
|
||||
cs.callbacks.Close()
|
||||
cs.closedMu.Unlock()
|
||||
for _, b := range backlog {
|
||||
b(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CallbackSerializer) fetchPendingCallbacks() []func(context.Context) {
|
||||
var backlog []func(context.Context)
|
||||
for {
|
||||
select {
|
||||
case b := <-cs.callbacks.Get():
|
||||
backlog = append(backlog, b.(func(context.Context)))
|
||||
cs.callbacks.Load()
|
||||
default:
|
||||
return backlog
|
||||
}
|
||||
// Run all pending callbacks.
|
||||
for cb := range cs.callbacks.Get() {
|
||||
cs.callbacks.Load()
|
||||
cb.(func(context.Context))(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue