diff --git a/internal/grpcsync/callback_serializer.go b/internal/grpcsync/callback_serializer.go
index 6df798c00..d91f92463 100644
--- a/internal/grpcsync/callback_serializer.go
+++ b/internal/grpcsync/callback_serializer.go
@@ -31,6 +31,12 @@ import (
 //
 // This type is safe for concurrent access.
 type CallbackSerializer struct {
+	// Done is closed once the serializer is shut down completely, i.e. once
+	// the scheduled callback (if any) that was running when the context
+	// passed to NewCallbackSerializer was canceled has completed, and the
+	// serializer has deallocated all its resources.
+	Done chan struct{}
+
 	callbacks *buffer.Unbounded
 }
 
@@ -39,7 +45,10 @@ type CallbackSerializer struct {
 // provided context to shut down the CallbackSerializer. It is guaranteed that no
 // callbacks will be executed once this context is canceled.
 func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
-	t := &CallbackSerializer{callbacks: buffer.NewUnbounded()}
+	t := &CallbackSerializer{
+		Done:      make(chan struct{}),
+		callbacks: buffer.NewUnbounded(),
+	}
 	go t.run(ctx)
 	return t
 }
@@ -53,6 +62,7 @@ func (t *CallbackSerializer) Schedule(f func(ctx context.Context)) {
 }
 
 func (t *CallbackSerializer) run(ctx context.Context) {
+	defer close(t.Done)
 	for ctx.Err() == nil {
 		select {
 		case <-ctx.Done():
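The new `Done` channel gives users of `CallbackSerializer` a way to block until the serializer has completely drained after its context is canceled. A minimal sketch of the intended shutdown sequence follows; it is illustrative only, and since `grpcsync` is an internal package it compiles only from inside the grpc module:

```go
package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/internal/grpcsync"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cs := grpcsync.NewCallbackSerializer(ctx)

	// Callbacks scheduled on the serializer run one at a time, in order.
	// Note: this callback is not guaranteed to run if cancel() wins the race.
	cs.Schedule(func(ctx context.Context) { fmt.Println("serialized work") })

	// Shutdown: cancel the context, then block on Done until the in-flight
	// callback (if any) has returned and resources have been released.
	cancel()
	<-cs.Done
}
```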
diff --git a/internal/grpcsync/callback_serializer_test.go b/internal/grpcsync/callback_serializer_test.go
index 6cb1ee52d..8c465af66 100644
--- a/internal/grpcsync/callback_serializer_test.go
+++ b/internal/grpcsync/callback_serializer_test.go
@@ -144,19 +144,13 @@ func (s) TestCallbackSerializer_Schedule_Close(t *testing.T) {
 	cs := NewCallbackSerializer(ctx)
 
 	// Schedule a callback which blocks until the context passed to it is
-	// canceled. It also closes a couple of channels to signal that it started
-	// and finished respectively.
+	// canceled. It also closes a channel to signal that it has started.
 	firstCallbackStartedCh := make(chan struct{})
-	firstCallbackFinishCh := make(chan struct{})
 	cs.Schedule(func(ctx context.Context) {
 		close(firstCallbackStartedCh)
 		<-ctx.Done()
-		close(firstCallbackFinishCh)
 	})
 
-	// Wait for the first callback to start before scheduling the others.
-	<-firstCallbackStartedCh
-
 	// Schedule a bunch of callbacks. These should not be executed since the first
 	// one started earlier is blocked.
 	const numCallbacks = 10
@@ -174,11 +168,14 @@ func (s) TestCallbackSerializer_Schedule_Close(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	// Wait for the first callback to start before closing the serializer.
+	<-firstCallbackStartedCh
+
 	// Cancel the context which will unblock the first callback. None of the
 	// other callbacks (which have not started executing at this point) should
 	// be executed after this.
 	cancel()
-	<-firstCallbackFinishCh
+	<-cs.Done
 
 	// Ensure that the newer callbacks are not executed.
 	select {
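The reordered test reflects the two guarantees the serializer now documents: callbacks execute strictly one at a time, and `Done` is closed only after the in-flight callback returns. The `resolver_conn_wrapper.go` changes below lean on the first guarantee to drop both of their mutexes. A self-contained sketch of that property (again assuming access to the internal package):

```go
package main

import (
	"context"
	"fmt"
	"sync"

	"google.golang.org/grpc/internal/grpcsync"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cs := grpcsync.NewCallbackSerializer(ctx)

	// counter is mutated only from serializer callbacks, so it needs no
	// mutex: no two callbacks ever run concurrently.
	counter := 0

	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		cs.Schedule(func(context.Context) {
			counter++
			wg.Done()
		})
	}
	wg.Wait()
	fmt.Println(counter) // always 100, with no data race

	cancel()
	<-cs.Done
}
```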
diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go
index 854e90f69..ce12b52ec 100644
--- a/resolver_conn_wrapper.go
+++ b/resolver_conn_wrapper.go
@@ -19,8 +19,8 @@
 package grpc
 
 import (
+	"context"
 	"strings"
-	"sync"
 
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/internal/channelz"
@@ -42,15 +42,17 @@ type ccResolverWrapper struct {
 	// The following fields are initialized when the wrapper is created and are
 	// read-only afterwards, and therefore can be accessed without a mutex.
 	cc                  resolverStateUpdater
-	done                *grpcsync.Event
 	channelzID          *channelz.Identifier
 	ignoreServiceConfig bool
 
-	resolverMu sync.Mutex
-	resolver   resolver.Resolver
-
-	incomingMu sync.Mutex // Synchronizes all the incoming calls.
-	curState   resolver.State
+	// Outgoing (gRPC --> resolver) and incoming (resolver --> gRPC) calls are
+	// guaranteed to execute in a mutually exclusive manner as they are
+	// scheduled on the CallbackSerializer. Fields accessed *only* in serializer
+	// callbacks can therefore be accessed without a mutex.
+	serializer       *grpcsync.CallbackSerializer
+	serializerCancel context.CancelFunc
+	resolver         resolver.Resolver
+	curState         resolver.State
 }
 
 // ccResolverWrapperOpts wraps the arguments to be passed when creating a new
@@ -65,104 +67,100 @@ type ccResolverWrapperOpts struct {
 // newCCResolverWrapper uses the resolver.Builder to build a Resolver and
 // returns a ccResolverWrapper object which wraps the newly built resolver.
 func newCCResolverWrapper(cc resolverStateUpdater, opts ccResolverWrapperOpts) (*ccResolverWrapper, error) {
+	ctx, cancel := context.WithCancel(context.Background())
 	ccr := &ccResolverWrapper{
 		cc:                  cc,
-		done:                grpcsync.NewEvent(),
 		channelzID:          opts.channelzID,
 		ignoreServiceConfig: opts.bOpts.DisableServiceConfig,
+		serializer:          grpcsync.NewCallbackSerializer(ctx),
+		serializerCancel:    cancel,
 	}
 
-	var err error
-	// We need to hold the lock here while we assign to the ccr.resolver field
-	// to guard against a data race caused by the following code path,
-	// rb.Build-->ccr.ReportError-->ccr.poll-->ccr.resolveNow, would end up
-	// accessing ccr.resolver which is being assigned here.
-	ccr.resolverMu.Lock()
-	defer ccr.resolverMu.Unlock()
-	ccr.resolver, err = opts.builder.Build(opts.target, ccr, opts.bOpts)
+	r, err := opts.builder.Build(opts.target, ccr, opts.bOpts)
 	if err != nil {
+		cancel()
 		return nil, err
 	}
+	ccr.resolver = r
 	return ccr, nil
 }
 
 func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOptions) {
-	ccr.resolverMu.Lock()
-	if !ccr.done.HasFired() {
+	ccr.serializer.Schedule(func(_ context.Context) {
 		ccr.resolver.ResolveNow(o)
-	}
-	ccr.resolverMu.Unlock()
+	})
}
 
 func (ccr *ccResolverWrapper) close() {
-	ccr.resolverMu.Lock()
+	// Close the serializer to ensure that no more calls from the resolver are
+	// handled before closing the resolver.
+	ccr.serializerCancel()
+	<-ccr.serializer.Done
 	ccr.resolver.Close()
-	ccr.done.Fire()
-	ccr.resolverMu.Unlock()
 }
 
 // UpdateState is called by resolver implementations to report new state to gRPC
 // which includes addresses and service config.
 func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error {
-	ccr.incomingMu.Lock()
-	defer ccr.incomingMu.Unlock()
-	if ccr.done.HasFired() {
+	errCh := make(chan error, 1)
+	ccr.serializer.Schedule(func(_ context.Context) {
+		ccr.addChannelzTraceEvent(s)
+		ccr.curState = s
+		if err := ccr.cc.updateResolverState(ccr.curState, nil); err == balancer.ErrBadResolverState {
+			errCh <- balancer.ErrBadResolverState
+			return
+		}
+		errCh <- nil
+	})
+
+	// If the resolver wrapper is closed while waiting for this state update to
+	// be handled, the callback serializer will be closed as well, and we can
+	// rely on its Done channel to ensure that we don't block here forever.
+	select {
+	case err := <-errCh:
+		return err
+	case <-ccr.serializer.Done:
 		return nil
 	}
-	ccr.addChannelzTraceEventLocked(s)
-	ccr.curState = s
-	if err := ccr.cc.updateResolverState(ccr.curState, nil); err == balancer.ErrBadResolverState {
-		return balancer.ErrBadResolverState
-	}
-	return nil
 }
 
 // ReportError is called by resolver implementations to report errors
 // encountered during name resolution to gRPC.
 func (ccr *ccResolverWrapper) ReportError(err error) {
-	ccr.incomingMu.Lock()
-	defer ccr.incomingMu.Unlock()
-	if ccr.done.HasFired() {
-		return
-	}
-	channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: reporting error to cc: %v", err)
-	ccr.cc.updateResolverState(resolver.State{}, err)
+	ccr.serializer.Schedule(func(_ context.Context) {
+		channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: reporting error to cc: %v", err)
+		ccr.cc.updateResolverState(resolver.State{}, err)
+	})
 }
 
 // NewAddress is called by the resolver implementation to send addresses to
 // gRPC.
 func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
-	ccr.incomingMu.Lock()
-	defer ccr.incomingMu.Unlock()
-	if ccr.done.HasFired() {
-		return
-	}
-	ccr.addChannelzTraceEventLocked(resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig})
-	ccr.curState.Addresses = addrs
-	ccr.cc.updateResolverState(ccr.curState, nil)
+	ccr.serializer.Schedule(func(_ context.Context) {
+		ccr.addChannelzTraceEvent(resolver.State{Addresses: addrs, ServiceConfig: ccr.curState.ServiceConfig})
+		ccr.curState.Addresses = addrs
+		ccr.cc.updateResolverState(ccr.curState, nil)
+	})
 }
 
 // NewServiceConfig is called by the resolver implementation to send service
 // configs to gRPC.
 func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
-	ccr.incomingMu.Lock()
-	defer ccr.incomingMu.Unlock()
-	if ccr.done.HasFired() {
-		return
-	}
-	channelz.Infof(logger, ccr.channelzID, "ccResolverWrapper: got new service config: %s", sc)
-	if ccr.ignoreServiceConfig {
-		channelz.Info(logger, ccr.channelzID, "Service config lookups disabled; ignoring config")
-		return
-	}
-	scpr := parseServiceConfig(sc)
-	if scpr.Err != nil {
-		channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: error parsing service config: %v", scpr.Err)
-		return
-	}
-	ccr.addChannelzTraceEventLocked(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: scpr})
-	ccr.curState.ServiceConfig = scpr
-	ccr.cc.updateResolverState(ccr.curState, nil)
+	ccr.serializer.Schedule(func(_ context.Context) {
+		channelz.Infof(logger, ccr.channelzID, "ccResolverWrapper: got new service config: %s", sc)
+		if ccr.ignoreServiceConfig {
+			channelz.Info(logger, ccr.channelzID, "Service config lookups disabled; ignoring config")
+			return
+		}
+		scpr := parseServiceConfig(sc)
+		if scpr.Err != nil {
+			channelz.Warningf(logger, ccr.channelzID, "ccResolverWrapper: error parsing service config: %v", scpr.Err)
+			return
+		}
+		ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: scpr})
+		ccr.curState.ServiceConfig = scpr
+		ccr.cc.updateResolverState(ccr.curState, nil)
+	})
 }
 
 // ParseServiceConfig is called by resolver implementations to parse a JSON
@@ -171,11 +169,9 @@ func (ccr *ccResolverWrapper) ParseServiceConfig(scJSON string) *serviceconfig.P
 	return parseServiceConfig(scJSON)
 }
 
-// addChannelzTraceEventLocked adds a channelz trace event containing the new
+// addChannelzTraceEvent adds a channelz trace event containing the new
 // state received from resolver implementations.
-//
-// Caller must hold cc.incomingMu.
-func (ccr *ccResolverWrapper) addChannelzTraceEventLocked(s resolver.State) {
+func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
 	var updates []string
 	var oldSC, newSC *ServiceConfig
 	var oldOK, newOK bool
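`UpdateState` is the one resolver-facing method above that must return an error to its caller, so it cannot simply fire and forget like `ReportError` or `NewAddress`. It therefore waits on a buffered result channel, but also selects on the serializer's `Done` channel so that a concurrent `close()` cannot leave the caller blocked on a callback that will never run. A stripped-down version of the pattern (the helper name is illustrative, not from the patch):

```go
package main

import (
	"context"
	"errors"
	"fmt"

	"google.golang.org/grpc/internal/grpcsync"
)

// scheduleAndWait mirrors the shape of UpdateState: run work on the
// serializer, then wait for its result or for serializer shutdown. errCh is
// buffered so the callback never blocks on send, even if the caller has
// already returned after seeing Done closed.
func scheduleAndWait(cs *grpcsync.CallbackSerializer, work func() error) error {
	errCh := make(chan error, 1)
	cs.Schedule(func(context.Context) {
		errCh <- work()
	})
	select {
	case err := <-errCh:
		return err
	case <-cs.Done:
		// Serializer shut down; the callback may never run. Report success
		// rather than blocking forever, matching UpdateState above.
		return nil
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	cs := grpcsync.NewCallbackSerializer(ctx)
	fmt.Println(scheduleAndWait(cs, func() error { return errors.New("bad resolver state") }))
}
```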
diff --git a/test/service_config_deprecated_test.go b/test/service_config_deprecated_test.go
index 035f11526..ecf43a576 100644
--- a/test/service_config_deprecated_test.go
+++ b/test/service_config_deprecated_test.go
@@ -146,15 +146,18 @@ func testServiceConfigWaitForReadyTD(t *testing.T, e env) {
 	ch <- sc
 
 	// Wait for the new service config to take effect.
-	mc = cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall")
-	for {
-		if !*mc.WaitForReady {
-			time.Sleep(100 * time.Millisecond)
-			mc = cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall")
-			continue
+	ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
+		mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall")
+		if *mc.WaitForReady {
+			break
 		}
-		break
 	}
+	if ctx.Err() != nil {
+		t.Fatalf("Timeout when waiting for service config to take effect")
+	}
+
 	// The following RPCs are expected to become non-fail-fast ones with 1ms deadline.
 	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded {
 		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
@@ -212,14 +215,16 @@ func testServiceConfigTimeoutTD(t *testing.T, e env) {
 	ch <- sc
 
 	// Wait for the new service config to take effect.
-	mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall")
-	for {
-		if *mc.Timeout != time.Nanosecond {
-			time.Sleep(100 * time.Millisecond)
-			mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall")
-			continue
+	ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
+		mc = cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall")
+		if *mc.Timeout == time.Nanosecond {
+			break
 		}
-		break
+	}
+	if ctx.Err() != nil {
+		t.Fatalf("Timeout when waiting for service config to take effect")
 	}
 
 	ctx, cancel = context.WithTimeout(context.Background(), time.Hour)
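Both rewritten waits use the same poll-with-deadline idiom: the `for` statement's post clause sleeps between attempts, and the context bounds the total wait so a config that never takes effect fails the test instead of hanging it. If more call sites adopt it, the idiom could be factored into a helper along these lines (hypothetical, not part of this change; it assumes the test package's existing `defaultTestTimeout` and `defaultTestShortTimeout` constants):

```go
package test

import (
	"context"
	"testing"
	"time"
)

// awaitCondition is a hypothetical helper that polls pred until it returns
// true or the deadline expires, mirroring the loops in the diff above.
func awaitCondition(t *testing.T, pred func() bool) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
		if pred() {
			return
		}
	}
	t.Fatal("timeout waiting for condition to become true")
}
```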