grpc: use CallbackSerializer in resolver_wrapper (#6234)

This commit is contained in:
Easwar Swaminathan 2023-05-04 16:05:13 -07:00 committed by GitHub
parent 47b3c5545c
commit ccad7b7570
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 92 deletions

View File

@ -31,6 +31,12 @@ import (
// //
// This type is safe for concurrent access. // This type is safe for concurrent access.
type CallbackSerializer struct { type CallbackSerializer struct {
// Done is closed once the serializer is shut down completely, i.e., once a
// scheduled callback, if any, that was running when the context passed to
// NewCallbackSerializer was canceled, has completed and the serializer has
// deallocated all its resources.
Done chan struct{}
callbacks *buffer.Unbounded callbacks *buffer.Unbounded
} }
@ -39,7 +45,10 @@ type CallbackSerializer struct {
// provided context to shutdown the CallbackSerializer. It is guaranteed that no // provided context to shutdown the CallbackSerializer. It is guaranteed that no
// callbacks will be executed once this context is canceled. // callbacks will be executed once this context is canceled.
func NewCallbackSerializer(ctx context.Context) *CallbackSerializer { func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
t := &CallbackSerializer{callbacks: buffer.NewUnbounded()} t := &CallbackSerializer{
Done: make(chan struct{}),
callbacks: buffer.NewUnbounded(),
}
go t.run(ctx) go t.run(ctx)
return t return t
} }
@ -53,6 +62,7 @@ func (t *CallbackSerializer) Schedule(f func(ctx context.Context)) {
} }
func (t *CallbackSerializer) run(ctx context.Context) { func (t *CallbackSerializer) run(ctx context.Context) {
defer close(t.Done)
for ctx.Err() == nil { for ctx.Err() == nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():

View File

@ -144,19 +144,13 @@ func (s) TestCallbackSerializer_Schedule_Close(t *testing.T) {
cs := NewCallbackSerializer(ctx) cs := NewCallbackSerializer(ctx)
// Schedule a callback which blocks until the context passed to it is // Schedule a callback which blocks until the context passed to it is
// canceled. It also closes a couple of channels to signal that it started // canceled. It also closes a channel to signal that it has started.
// and finished respectively.
firstCallbackStartedCh := make(chan struct{}) firstCallbackStartedCh := make(chan struct{})
firstCallbackFinishCh := make(chan struct{})
cs.Schedule(func(ctx context.Context) { cs.Schedule(func(ctx context.Context) {
close(firstCallbackStartedCh) close(firstCallbackStartedCh)
<-ctx.Done() <-ctx.Done()
close(firstCallbackFinishCh)
}) })
// Wait for the first callback to start before scheduling the others.
<-firstCallbackStartedCh
// Schedule a bunch of callbacks. These should not be executed since the first // Schedule a bunch of callbacks. These should not be executed since the first
// one started earlier is blocked. // one started earlier is blocked.
const numCallbacks = 10 const numCallbacks = 10
@ -174,11 +168,14 @@ func (s) TestCallbackSerializer_Schedule_Close(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// Wait for the first callback to start before closing the scheduler.
<-firstCallbackStartedCh
// Cancel the context which will unblock the first callback. None of the // Cancel the context which will unblock the first callback. None of the
// other callbacks (which have not started executing at this point) should // other callbacks (which have not started executing at this point) should
// be executed after this. // be executed after this.
cancel() cancel()
<-firstCallbackFinishCh <-cs.Done
// Ensure that the newer callbacks are not executed. // Ensure that the newer callbacks are not executed.
select { select {

View File

@ -19,8 +19,8 @@
package grpc package grpc
import ( import (
"context"
"strings" "strings"
"sync"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
@ -42,14 +42,16 @@ type ccResolverWrapper struct {
// The following fields are initialized when the wrapper is created and are // The following fields are initialized when the wrapper is created and are
// read-only afterwards, and therefore can be accessed without a mutex. // read-only afterwards, and therefore can be accessed without a mutex.
cc resolverStateUpdater cc resolverStateUpdater
done *grpcsync.Event
channelzID *channelz.Identifier channelzID *channelz.Identifier
ignoreServiceConfig bool ignoreServiceConfig bool
resolverMu sync.Mutex // Outgoing (gRPC --> resolver) and incoming (resolver --> gRPC) calls are
// guaranteed to execute in a mutually exclusive manner as they are
// scheduled on the CallbackSerializer. Fields accessed *only* in serializer
// callbacks can therefore be accessed without a mutex.
serializer *grpcsync.CallbackSerializer
serializerCancel context.CancelFunc
resolver resolver.Resolver resolver resolver.Resolver
incomingMu sync.Mutex // Synchronizes all the incoming calls.
curState resolver.State curState resolver.State
} }
@ -65,91 +67,86 @@ type ccResolverWrapperOpts struct {
// newCCResolverWrapper uses the resolver.Builder to build a Resolver and // newCCResolverWrapper uses the resolver.Builder to build a Resolver and
// returns a ccResolverWrapper object which wraps the newly built resolver. // returns a ccResolverWrapper object which wraps the newly built resolver.
func newCCResolverWrapper(cc resolverStateUpdater, opts ccResolverWrapperOpts) (*ccResolverWrapper, error) { func newCCResolverWrapper(cc resolverStateUpdater, opts ccResolverWrapperOpts) (*ccResolverWrapper, error) {
ctx, cancel := context.WithCancel(context.Background())
ccr := &ccResolverWrapper{ ccr := &ccResolverWrapper{
cc: cc, cc: cc,
done: grpcsync.NewEvent(),
channelzID: opts.channelzID, channelzID: opts.channelzID,
ignoreServiceConfig: opts.bOpts.DisableServiceConfig, ignoreServiceConfig: opts.bOpts.DisableServiceConfig,
serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel,
} }
var err error r, err := opts.builder.Build(opts.target, ccr, opts.bOpts)
// We need to hold the lock here while we assign to the ccr.resolver field
// to guard against a data race caused by the following code path,
// rb.Build-->ccr.ReportError-->ccr.poll-->ccr.resolveNow, would end up
// accessing ccr.resolver which is being assigned here.
ccr.resolverMu.Lock()
defer ccr.resolverMu.Unlock()
ccr.resolver, err = opts.builder.Build(opts.target, ccr, opts.bOpts)
if err != nil { if err != nil {
cancel()
return nil, err return nil, err
} }
ccr.resolver = r
return ccr, nil return ccr, nil
} }
func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) { func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
ccr.resolverMu.Lock() ccr.serializer.Schedule(func(_ context.Context) {
if !ccr.done.HasFired() {
ccr.resolver.ResolveNow(o) ccr.resolver.ResolveNow(o)
} })
ccr.resolverMu.Unlock()
} }
func (ccr *ccResolverWrapper) close() { func (ccr *ccResolverWrapper) close() {
ccr.resolverMu.Lock() // Close the serializer to ensure that no more calls from the resolver are
// handled, before closing the resolver.
ccr.serializerCancel()
<-ccr.serializer.Done
ccr.resolver.Close() ccr.resolver.Close()
ccr.done.Fire()
ccr.resolverMu.Unlock()
} }
// UpdateState is called by resolver implementations to report new state to gRPC // UpdateState is called by resolver implementations to report new state to gRPC
// which includes addresses and service config. // which includes addresses and service config.
func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error { func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error {
ccr.incomingMu.Lock() errCh := make(chan error, 1)
defer ccr.incomingMu.Unlock() ccr.serializer.Schedule(func(_ context.Context) {
if ccr.done.HasFired() { ccr.addChannelzTraceEvent(s)
return nil
}
ccr.addChannelzTraceEventLocked(s)
ccr.curState = s ccr.curState = s
if err := ccr.cc.updateResolverState(ccr.curState, nil); err == balancer.ErrBadResolverState { if err := ccr.cc.updateResolverState(ccr.curState, nil); err == balancer.ErrBadResolverState {
return balancer.ErrBadResolverState errCh <- balancer.ErrBadResolverState
return
} }
errCh <- nil
})
// If the resolver wrapper is closed when waiting for this state update to
// be handled, the callback serializer will be closed as well, and we can
// rely on its Done channel to ensure that we don't block here forever.
select {
case err := <-errCh:
return err
case <-ccr.serializer.Done:
return nil return nil
} }
}
// ReportError is called by resolver implementations to report errors // ReportError is called by resolver implementations to report errors
// encountered during name resolution to gRPC. // encountered during name resolution to gRPC.
func (ccr *ccResolverWrapper) ReportError(err error) { func (ccr *ccResolverWrapper) ReportError(err error) {
ccr.incomingMu.Lock() ccr.serializer.Schedule(func(_ context.Context) {
defer ccr.incomingMu.Unlock()
if ccr.done.HasFired() {
return
}
channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: reporting error to cc: %v", err) channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: reporting error to cc: %v", err)
ccr.cc.updateResolverState(resolver.State{}, err) ccr.cc.updateResolverState(resolver.State{}, err)
})
} }
// NewAddress is called by the resolver implementation to send addresses to // NewAddress is called by the resolver implementation to send addresses to
// gRPC. // gRPC.
func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) { func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
ccr.incomingMu.Lock() ccr.serializer.Schedule(func(_ context.Context) {
defer ccr.incomingMu.Unlock() ccr.addChannelzTraceEvent(resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig})
if ccr.done.HasFired() {
return
}
ccr.addChannelzTraceEventLocked(resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig})
ccr.curState.Addresses = addrs ccr.curState.Addresses = addrs
ccr.cc.updateResolverState(ccr.curState, nil) ccr.cc.updateResolverState(ccr.curState, nil)
})
} }
// NewServiceConfig is called by the resolver implementation to send service // NewServiceConfig is called by the resolver implementation to send service
// configs to gRPC. // configs to gRPC.
func (ccr *ccResolverWrapper) NewServiceConfig(sc string) { func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
ccr.incomingMu.Lock() ccr.serializer.Schedule(func(_ context.Context) {
defer ccr.incomingMu.Unlock()
if ccr.done.HasFired() {
return
}
channelz.Infof(logger, ccr.channelzID, "ccResolverWrapper: got new service config: %s", sc) channelz.Infof(logger, ccr.channelzID, "ccResolverWrapper: got new service config: %s", sc)
if ccr.ignoreServiceConfig { if ccr.ignoreServiceConfig {
channelz.Info(logger, ccr.channelzID, "Service config lookups disabled; ignoring config") channelz.Info(logger, ccr.channelzID, "Service config lookups disabled; ignoring config")
@ -160,9 +157,10 @@ func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: error parsing service config: %v", scpr.Err) channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: error parsing service config: %v", scpr.Err)
return return
} }
ccr.addChannelzTraceEventLocked(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: scpr}) ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: scpr})
ccr.curState.ServiceConfig = scpr ccr.curState.ServiceConfig = scpr
ccr.cc.updateResolverState(ccr.curState, nil) ccr.cc.updateResolverState(ccr.curState, nil)
})
} }
// ParseServiceConfig is called by resolver implementations to parse a JSON // ParseServiceConfig is called by resolver implementations to parse a JSON
@ -171,11 +169,9 @@ func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.P
return parseServiceConfig(scJSON) return parseServiceConfig(scJSON)
} }
// addChannelzTraceEventLocked adds a channelz trace event containing the new // addChannelzTraceEvent adds a channelz trace event containing the new
// state received from resolver implementations. // state received from resolver implementations.
// func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
// Caller must hold cc.incomingMu.
func (ccr *ccResolverWrapper) addChannelzTraceEventLocked(s resolver.State) {
var updates []string var updates []string
var oldSC, newSC *ServiceConfig var oldSC, newSC *ServiceConfig
var oldOK, newOK bool var oldOK, newOK bool

View File

@ -146,15 +146,18 @@ func testServiceConfigWaitForReadyTD(t *testing.T, e env) {
ch <- sc ch <- sc
// Wait for the new service config to take effect. // Wait for the new service config to take effect.
mc = cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall") ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
for { defer cancel()
if !*mc.WaitForReady { for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
time.Sleep(100 * time.Millisecond) mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall")
mc = cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall") if *mc.WaitForReady {
continue
}
break break
} }
}
if ctx.Err() != nil {
t.Fatalf("Timeout when waiting for service config to take effect")
}
// The following RPCs are expected to become non-fail-fast ones with 1ms deadline. // The following RPCs are expected to become non-fail-fast ones with 1ms deadline.
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded { if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
@ -212,15 +215,17 @@ func testServiceConfigTimeoutTD(t *testing.T, e env) {
ch <- sc ch <- sc
// Wait for the new service config to take effect. // Wait for the new service config to take effect.
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall") mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall")
for { if *mc.Timeout == time.Nanosecond {
if *mc.Timeout != time.Nanosecond {
time.Sleep(100 * time.Millisecond)
mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall")
continue
}
break break
} }
}
if ctx.Err() != nil {
t.Fatalf("Timeout when waiting for service config to take effect")
}
ctx, cancel = context.WithTimeout(context.Background(), time.Hour) ctx, cancel = context.WithTimeout(context.Background(), time.Hour)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != codes.DeadlineExceeded { if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != codes.DeadlineExceeded {