diff --git a/clientconn.go b/clientconn.go index d607d4e9e..e3919895e 100644 --- a/clientconn.go +++ b/clientconn.go @@ -133,6 +133,10 @@ func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*ires // https://github.com/grpc/grpc/blob/master/doc/naming.md. // e.g. to use dns resolver, a "dns:///" prefix should be applied to the target. func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { + return dialWithGlobalOptions(ctx, target, false, opts...) +} + +func dialWithGlobalOptions(ctx context.Context, target string, disableGlobalOptions bool, opts ...DialOption) (conn *ClientConn, err error) { cc := &ClientConn{ target: target, csMgr: &connectivityStateManager{}, @@ -146,8 +150,10 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * cc.safeConfigSelector.UpdateConfigSelector(&defaultConfigSelector{nil}) cc.ctx, cc.cancel = context.WithCancel(context.Background()) - for _, opt := range extraDialOptions { - opt.apply(&cc.dopts) + if !disableGlobalOptions { + for _, opt := range globalDialOptions { + opt.apply(&cc.dopts) + } } for _, opt := range opts { diff --git a/default_dial_option_server_option_test.go b/default_dial_option_server_option_test.go index c6cdd7c84..b1501d2fb 100644 --- a/default_dial_option_server_option_test.go +++ b/default_dial_option_server_option_test.go @@ -19,6 +19,7 @@ package grpc import ( + "context" "strings" "testing" @@ -26,7 +27,7 @@ import ( "google.golang.org/grpc/internal" ) -func (s) TestAddExtraDialOptions(t *testing.T) { +func (s) TestAddGlobalDialOptions(t *testing.T) { // Ensure the Dial fails without credentials if _, err := Dial("fake"); err == nil { t.Fatalf("Dialing without a credential did not fail") @@ -40,8 +41,8 @@ func (s) TestAddExtraDialOptions(t *testing.T) { opts := []DialOption{WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials())} internal.AddGlobalDialOptions.(func(opt ...DialOption))(opts...) for i, opt := range opts { - if extraDialOptions[i] != opt { - t.Fatalf("Unexpected extra dial option at index %d: %v != %v", i, extraDialOptions[i], opt) + if globalDialOptions[i] != opt { + t.Fatalf("Unexpected global dial option at index %d: %v != %v", i, globalDialOptions[i], opt) } } @@ -53,19 +54,37 @@ func (s) TestAddExtraDialOptions(t *testing.T) { } internal.ClearGlobalDialOptions() - if len(extraDialOptions) != 0 { - t.Fatalf("Unexpected len of extraDialOptions: %d != 0", len(extraDialOptions)) + if len(globalDialOptions) != 0 { + t.Fatalf("Unexpected len of globalDialOptions: %d != 0", len(globalDialOptions)) } } -func (s) TestAddExtraServerOptions(t *testing.T) { +// TestDisableGlobalOptions tests dialing with a bit that disables global +// options. Dialing with this bit set should not pick up global options. +func (s) TestDisableGlobalOptions(t *testing.T) { + // Set transport credentials as a global option. + internal.AddGlobalDialOptions.(func(opt ...DialOption))(WithTransportCredentials(insecure.NewCredentials())) + // Dial with disable global options set to true. This Dial should fail due + // to the global dial options with credentials not being picked up due to it + // being disabled. + if _, err := internal.DialWithGlobalOptions.(func(context.Context, string, bool, ...DialOption) (*ClientConn, error))(context.Background(), "fake", true); err == nil { + t.Fatalf("Dialing without a credential did not fail") + } else { + if !strings.Contains(err.Error(), "no transport security set") { + t.Fatalf("Dialing failed with unexpected error: %v", err) + } + } + internal.ClearGlobalDialOptions() +} + +func (s) TestAddGlobalServerOptions(t *testing.T) { const maxRecvSize = 998765 // Set and check the ServerOptions opts := []ServerOption{Creds(insecure.NewCredentials()), MaxRecvMsgSize(maxRecvSize)} internal.AddGlobalServerOptions.(func(opt ...ServerOption))(opts...) for i, opt := range opts { - if extraServerOptions[i] != opt { - t.Fatalf("Unexpected extra server option at index %d: %v != %v", i, extraServerOptions[i], opt) + if globalServerOptions[i] != opt { + t.Fatalf("Unexpected global server option at index %d: %v != %v", i, globalServerOptions[i], opt) } } @@ -76,8 +95,8 @@ func (s) TestAddExtraServerOptions(t *testing.T) { } internal.ClearGlobalServerOptions() - if len(extraServerOptions) != 0 { - t.Fatalf("Unexpected len of extraServerOptions: %d != 0", len(extraServerOptions)) + if len(globalServerOptions) != 0 { + t.Fatalf("Unexpected len of globalServerOptions: %d != 0", len(globalServerOptions)) } } diff --git a/dialoptions.go b/dialoptions.go index 4866da101..67f240496 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -38,11 +38,12 @@ import ( func init() { internal.AddGlobalDialOptions = func(opt ...DialOption) { - extraDialOptions = append(extraDialOptions, opt...) + globalDialOptions = append(globalDialOptions, opt...) } internal.ClearGlobalDialOptions = func() { - extraDialOptions = nil + globalDialOptions = nil } + internal.DialWithGlobalOptions = dialWithGlobalOptions internal.WithBinaryLogger = withBinaryLogger internal.JoinDialOptions = newJoinDialOption } @@ -83,7 +84,7 @@ type DialOption interface { apply(*dialOptions) } -var extraDialOptions []DialOption +var globalDialOptions []DialOption // EmptyDialOption does not alter the dial configuration. It can be embedded in // another structure to build custom dial options. diff --git a/internal/internal.go b/internal/internal.go index 0a76d9de6..cb5139a19 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -29,6 +29,9 @@ import ( ) var ( + // DialWithGlobalOptions dials with a knob on whether to disable global dial + // options (set via AddGlobalDialOptions). + DialWithGlobalOptions interface{} // func (context.Context, string, bool, ...DialOption) (*ClientConn, error) // WithHealthCheckFunc is set by dialoptions.go WithHealthCheckFunc interface{} // func (HealthChecker) DialOption // HealthCheckFunc is used to provide client-side LB channel health checking diff --git a/server.go b/server.go index d5a6e78be..0ebaaf5da 100644 --- a/server.go +++ b/server.go @@ -74,10 +74,10 @@ func init() { srv.drainServerTransports(addr) } internal.AddGlobalServerOptions = func(opt ...ServerOption) { - extraServerOptions = append(extraServerOptions, opt...) + globalServerOptions = append(globalServerOptions, opt...) } internal.ClearGlobalServerOptions = func() { - extraServerOptions = nil + globalServerOptions = nil } internal.BinaryLogger = binaryLogger internal.JoinServerOptions = newJoinServerOption @@ -183,7 +183,7 @@ var defaultServerOptions = serverOptions{ writeBufferSize: defaultWriteBufSize, readBufferSize: defaultReadBufSize, } -var extraServerOptions []ServerOption +var globalServerOptions []ServerOption // A ServerOption sets options such as credentials, codec and keepalive parameters, etc. type ServerOption interface { @@ -600,7 +600,7 @@ func (s *Server) stopServerWorkers() { // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { opts := defaultServerOptions - for _, o := range extraServerOptions { + for _, o := range globalServerOptions { o.apply(&opts) } for _, o := range opt {