From 750abe8f95cd270ab68f7298a2854148d0c33030 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Tue, 8 Dec 2020 13:32:37 -0800 Subject: [PATCH] resolver: allow config selector to return an RPC error (#4082) --- clientconn.go | 4 ++-- internal/resolver/config_selector.go | 8 ++++--- internal/resolver/config_selector_test.go | 21 +++++++++--------- stream.go | 5 ++++- test/resolver_test.go | 27 ++++++++++++++++------- 5 files changed, 41 insertions(+), 24 deletions(-) diff --git a/clientconn.go b/clientconn.go index b35ba94f2..abfd20098 100644 --- a/clientconn.go +++ b/clientconn.go @@ -109,11 +109,11 @@ type defaultConfigSelector struct { sc *ServiceConfig } -func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) *iresolver.RPCConfig { +func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RPCConfig, error) { return &iresolver.RPCConfig{ Context: rpcInfo.Context, MethodConfig: getMethodConfig(dcs.sc, rpcInfo.Method), - } + }, nil } // DialContext creates a client connection to the given target. By default, it's diff --git a/internal/resolver/config_selector.go b/internal/resolver/config_selector.go index 5ef9262e5..e69900400 100644 --- a/internal/resolver/config_selector.go +++ b/internal/resolver/config_selector.go @@ -29,8 +29,10 @@ import ( // ConfigSelector controls what configuration to use for every RPC. type ConfigSelector interface { - // Selects the configuration for the RPC. - SelectConfig(RPCInfo) *RPCConfig + // Selects the configuration for the RPC, or terminates it using the error. + // This error will be converted by the gRPC library to a status error with + // code UNKNOWN if it is not returned as a status error. + SelectConfig(RPCInfo) (*RPCConfig, error) } // RPCInfo contains RPC information needed by a ConfigSelector. @@ -86,7 +88,7 @@ func (scs *SafeConfigSelector) UpdateConfigSelector(cs ConfigSelector) { } // SelectConfig defers to the current ConfigSelector in scs. -func (scs *SafeConfigSelector) SelectConfig(r RPCInfo) *RPCConfig { +func (scs *SafeConfigSelector) SelectConfig(r RPCInfo) (*RPCConfig, error) { scs.mu.RLock() defer scs.mu.RUnlock() return scs.cs.SelectConfig(r) diff --git a/internal/resolver/config_selector_test.go b/internal/resolver/config_selector_test.go index f41b8eb0e..e5a50995d 100644 --- a/internal/resolver/config_selector_test.go +++ b/internal/resolver/config_selector_test.go @@ -36,10 +36,10 @@ func Test(t *testing.T) { } type fakeConfigSelector struct { - selectConfig func(RPCInfo) *RPCConfig + selectConfig func(RPCInfo) (*RPCConfig, error) } -func (f *fakeConfigSelector) SelectConfig(r RPCInfo) *RPCConfig { +func (f *fakeConfigSelector) SelectConfig(r RPCInfo) (*RPCConfig, error) { return f.selectConfig(r) } @@ -59,21 +59,21 @@ func (s) TestSafeConfigSelector(t *testing.T) { cs2Called := make(chan struct{}) cs1 := &fakeConfigSelector{ - selectConfig: func(r RPCInfo) *RPCConfig { + selectConfig: func(r RPCInfo) (*RPCConfig, error) { cs1Called <- struct{}{} if diff := cmp.Diff(r, testRPCInfo); diff != "" { t.Errorf("SelectConfig(%v) called; want %v\n Diffs:\n%s", r, testRPCInfo, diff) } - return <-retChan1 + return <-retChan1, nil }, } cs2 := &fakeConfigSelector{ - selectConfig: func(r RPCInfo) *RPCConfig { + selectConfig: func(r RPCInfo) (*RPCConfig, error) { cs2Called <- struct{}{} if diff := cmp.Diff(r, testRPCInfo); diff != "" { t.Errorf("SelectConfig(%v) called; want %v\n Diffs:\n%s", r, testRPCInfo, diff) } - return <-retChan2 + return <-retChan2, nil }, } @@ -82,9 +82,9 @@ func (s) TestSafeConfigSelector(t *testing.T) { cs1Returned := make(chan struct{}) go func() { - got := scs.SelectConfig(testRPCInfo) // blocks until send to retChan1 - if got != resp1 { - t.Errorf("SelectConfig(%v) = %v; want %v", testRPCInfo, got, resp1) + got, err := scs.SelectConfig(testRPCInfo) // blocks until send to retChan1 + if err != nil || got != resp1 { + t.Errorf("SelectConfig(%v) = %v, %v; want %v, nil", testRPCInfo, got, err, resp1) } close(cs1Returned) }() @@ -112,7 +112,8 @@ func (s) TestSafeConfigSelector(t *testing.T) { for dl := time.Now().Add(150 * time.Millisecond); !time.Now().After(dl); { gotConfigChan := make(chan *RPCConfig) go func() { - gotConfigChan <- scs.SelectConfig(testRPCInfo) + cfg, _ := scs.SelectConfig(testRPCInfo) + gotConfigChan <- cfg }() select { case <-time.After(500 * time.Millisecond): diff --git a/stream.go b/stream.go index 8d4694a33..eda1248d6 100644 --- a/stream.go +++ b/stream.go @@ -175,7 +175,10 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth var mc serviceconfig.MethodConfig var onCommit func() - rpcConfig := cc.safeConfigSelector.SelectConfig(iresolver.RPCInfo{Context: ctx, Method: method}) + rpcConfig, err := cc.safeConfigSelector.SelectConfig(iresolver.RPCInfo{Context: ctx, Method: method}) + if err != nil { + return nil, status.Convert(err).Err() + } if rpcConfig != nil { if rpcConfig.Context != nil { ctx = rpcConfig.Context diff --git a/test/resolver_test.go b/test/resolver_test.go index ab154f0ea..648245aef 100644 --- a/test/resolver_test.go +++ b/test/resolver_test.go @@ -20,11 +20,13 @@ package test import ( "context" + "fmt" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/codes" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/internal/stubserver" @@ -32,14 +34,15 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" ) type funcConfigSelector struct { - f func(iresolver.RPCInfo) *iresolver.RPCConfig + f func(iresolver.RPCInfo) (*iresolver.RPCConfig, error) } -func (f funcConfigSelector) SelectConfig(i iresolver.RPCInfo) *iresolver.RPCConfig { +func (f funcConfigSelector) SelectConfig(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) { return f.f(i) } @@ -75,12 +78,14 @@ func (s) TestConfigSelector(t *testing.T) { testCases := []struct { name string - md metadata.MD - config *iresolver.RPCConfig + md metadata.MD // MD sent with RPC + config *iresolver.RPCConfig // config returned by config selector + csErr error // error returned by config selector wantMD metadata.MD wantDeadline time.Time wantTimeout time.Duration + wantErr error }{{ name: "basic", md: testMD, @@ -95,6 +100,10 @@ func (s) TestConfigSelector(t *testing.T) { }, wantMD: mdOut, wantDeadline: ctxDeadline, + }, { + name: "erroring SelectConfig", + csErr: status.Errorf(codes.Unavailable, "cannot send RPC"), + wantErr: status.Errorf(codes.Unavailable, "cannot send RPC"), }, { name: "alter timeout; remove MD", md: testMD, @@ -138,13 +147,13 @@ func (s) TestConfigSelector(t *testing.T) { Addresses: []resolver.Address{{Addr: ss.Address}}, ServiceConfig: parseCfg(ss.R, "{}"), }, funcConfigSelector{ - f: func(i iresolver.RPCInfo) *iresolver.RPCConfig { + f: func(i iresolver.RPCInfo) (*iresolver.RPCConfig, error) { gotInfo = &i cfg := tc.config if cfg != nil && cfg.Context == nil { cfg.Context = i.Context } - return cfg + return cfg, tc.csErr }, }) ss.R.UpdateState(state) // Blocks until config selector is applied @@ -152,8 +161,10 @@ func (s) TestConfigSelector(t *testing.T) { onCommittedCalled = false ctx := metadata.NewOutgoingContext(ctx, tc.md) startTime := time.Now() - if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { - t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, nil", err) + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); fmt.Sprint(err) != fmt.Sprint(tc.wantErr) { + t.Fatalf("client.EmptyCall(_, _) = _, %v; want _, %v", err, tc.wantErr) + } else if err != nil { + return // remaining checks are invalid } if gotInfo == nil {