diff --git a/dialoptions.go b/dialoptions.go index b7524f826..9f872df8b 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -69,6 +69,10 @@ type dialOptions struct { minConnectTimeout func() time.Duration defaultServiceConfig *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON. defaultServiceConfigRawJSON *string + // This is used by ccResolverWrapper to backoff between successive calls to + // resolver.ResolveNow(). The user will have no need to configure this, but + // we need to be able to configure this in tests. + resolveNowBackoff func(int) time.Duration } // DialOption configures how we set up the connection. @@ -559,6 +563,7 @@ func defaultDialOptions() dialOptions { WriteBufferSize: defaultWriteBufSize, ReadBufferSize: defaultReadBufSize, }, + resolveNowBackoff: internalbackoff.DefaultExponential.Backoff, } } @@ -572,3 +577,13 @@ func withMinConnectDeadline(f func() time.Duration) DialOption { o.minConnectTimeout = f }) } + +// withResolveNowBackoff specifies the function that clientconn uses to backoff +// between successive calls to resolver.ResolveNow(). +// +// For testing purpose only. +func withResolveNowBackoff(f func(int) time.Duration) DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.resolveNowBackoff = f + }) +} diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go index b4f2a5714..33198007b 100644 --- a/resolver_conn_wrapper.go +++ b/resolver_conn_wrapper.go @@ -24,27 +24,25 @@ import ( "sync" "time" - "google.golang.org/grpc/backoff" "google.golang.org/grpc/balancer" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" - - internalbackoff "google.golang.org/grpc/internal/backoff" ) // ccResolverWrapper is a wrapper on top of cc for resolvers. // It implements resolver.ClientConnection interface. 
type ccResolverWrapper struct { - cc *ClientConn - resolver resolver.Resolver - done *grpcsync.Event - curState resolver.State + cc *ClientConn + resolverMu sync.Mutex + resolver resolver.Resolver + done *grpcsync.Event + curState resolver.State - mu sync.Mutex // protects polling - polling chan struct{} + pollingMu sync.Mutex + polling chan struct{} } // split2 returns the values from strings.SplitN(s, sep, 2). @@ -93,35 +91,39 @@ func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { } var err error + // Hold resolverMu while assigning ccr.resolver to guard against a data race on + // the code path rb.Build-->ccr.ReportError-->ccr.poll-->ccr.resolveNow, which + // would otherwise read ccr.resolver while it is being assigned here. Unlock + // before the error check so a failed Build does not leave the mutex held. + ccr.resolverMu.Lock() ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{DisableServiceConfig: cc.dopts.disableServiceConfig}) + ccr.resolverMu.Unlock() if err != nil { return nil, err } return ccr, nil } func (ccr *ccResolverWrapper) resolveNow(o resolver.ResolveNowOption) { - ccr.mu.Lock() + ccr.resolverMu.Lock() if !ccr.done.HasFired() { ccr.resolver.ResolveNow(o) } - ccr.mu.Unlock() + ccr.resolverMu.Unlock() } func (ccr *ccResolverWrapper) close() { - ccr.mu.Lock() + ccr.resolverMu.Lock() ccr.resolver.Close() ccr.done.Fire() - ccr.mu.Unlock() + ccr.resolverMu.Unlock() } -var resolverBackoff = internalbackoff.Exponential{Config: backoff.Config{MaxDelay: 2 * time.Minute}}.Backoff - // poll begins or ends asynchronous polling of the resolver based on whether // err is ErrBadResolverState. 
func (ccr *ccResolverWrapper) poll(err error) { - ccr.mu.Lock() - defer ccr.mu.Unlock() + ccr.pollingMu.Lock() + defer ccr.pollingMu.Unlock() if err != balancer.ErrBadResolverState { // stop polling if ccr.polling != nil { @@ -139,7 +141,7 @@ func (ccr *ccResolverWrapper) poll(err error) { go func() { for i := 0; ; i++ { ccr.resolveNow(resolver.ResolveNowOption{}) - t := time.NewTimer(resolverBackoff(i)) + t := time.NewTimer(ccr.cc.dopts.resolveNowBackoff(i)) select { case <-p: t.Stop() diff --git a/resolver_conn_wrapper_test.go b/resolver_conn_wrapper_test.go index 60e7d0d7a..5f78801f7 100644 --- a/resolver_conn_wrapper_test.go +++ b/resolver_conn_wrapper_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/status" ) @@ -122,10 +123,8 @@ func (s) TestDialParseTargetUnknownScheme(t *testing.T) { } func testResolverErrorPolling(t *testing.T, badUpdate func(*manual.Resolver), goodUpdate func(*manual.Resolver), dopts ...DialOption) { - defer func(o func(int) time.Duration) { resolverBackoff = o }(resolverBackoff) - boIter := make(chan int) - resolverBackoff = func(v int) time.Duration { + resolverBackoff := func(v int) time.Duration { boIter <- v return 0 } @@ -136,7 +135,11 @@ func testResolverErrorPolling(t *testing.T, badUpdate func(*manual.Resolver), go defer func() { close(rn) }() r.ResolveNowCallback = func(resolver.ResolveNowOption) { rn <- struct{}{} } - cc, err := Dial(r.Scheme()+":///test.server", append([]DialOption{WithInsecure()}, dopts...)...) + defaultDialOptions := []DialOption{ + WithInsecure(), + withResolveNowBackoff(resolverBackoff), + } + cc, err := Dial(r.Scheme()+":///test.server", append(defaultDialOptions, dopts...)...) 
if err != nil { t.Fatalf("Dial(_, _) = _, %v; want _, nil", err) } @@ -202,6 +205,31 @@ func (s) TestServiceConfigErrorPolling(t *testing.T) { }) } +// TestResolverErrorInBuild makes the resolver.Builder call into the ClientConn +// during the Build call. We use two separate mutexes in the code which make +// sure there is no data race in this code path, and also that there is no +// deadlock. +func (s) TestResolverErrorInBuild(t *testing.T) { + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + r.InitialState(resolver.State{ServiceConfig: &serviceconfig.ParseResult{Err: errors.New("resolver build err")}}) + + cc, err := Dial(r.Scheme()+":///test.server", WithInsecure()) + if err != nil { + t.Fatalf("Dial(_, _) = _, %v; want _, nil", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var dummy int + const wantMsg = "error parsing service config" + const wantCode = codes.Unavailable + if err := cc.Invoke(ctx, "/foo/bar", &dummy, &dummy); status.Code(err) != wantCode || !strings.Contains(status.Convert(err).Message(), wantMsg) { + t.Fatalf("cc.Invoke(_, _, _, _) = %v; want status.Code()==%v, status.Message() contains %q", err, wantCode, wantMsg) + } +} + func (s) TestServiceConfigErrorRPC(t *testing.T) { r, rcleanup := manual.GenerateAndRegisterManualResolver() defer rcleanup()