dialOption: export WithContextDialer() (#2629)

fixes #2627
This commit is contained in:
Menghan Li 2019-02-25 15:22:10 -08:00 committed by GitHub
parent 871b88ce2e
commit 40cb5618f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 31 additions and 32 deletions

View File

@ -279,10 +279,7 @@ func (lb *lbBalancer) dialRemoteLB(remoteLBName string) {
dopts = append(dopts, grpc.WithInsecure()) dopts = append(dopts, grpc.WithInsecure())
} }
if lb.opt.Dialer != nil { if lb.opt.Dialer != nil {
// WithDialer takes a different type of function, so we instead use a dopts = append(dopts, grpc.WithContextDialer(lb.opt.Dialer))
// special DialOption here.
wcd := internal.WithContextDialer.(func(func(context.Context, string) (net.Conn, error)) grpc.DialOption)
dopts = append(dopts, wcd(lb.opt.Dialer))
} }
// Explicitly set pickfirst as the balancer. // Explicitly set pickfirst as the balancer.
dopts = append(dopts, grpc.WithBalancerName(grpc.PickFirstBalancerName)) dopts = append(dopts, grpc.WithBalancerName(grpc.PickFirstBalancerName))

View File

@ -117,9 +117,9 @@ func (c *serverNameCheckCreds) OverrideServerName(s string) error {
// fakeNameDialer replaces fakeName with localhost when dialing. // fakeNameDialer replaces fakeName with localhost when dialing.
// This will test that custom dialer is passed from Dial to grpclb. // This will test that custom dialer is passed from Dial to grpclb.
func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) { func fakeNameDialer(ctx context.Context, addr string) (net.Conn, error) {
addr = strings.Replace(addr, fakeName, "localhost", 1) addr = strings.Replace(addr, fakeName, "localhost", 1)
return net.DialTimeout("tcp", addr, timeout) return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
} }
// merge merges the new client stats into current stats. // merge merges the new client stats into current stats.
@ -382,7 +382,7 @@ func TestGRPCLB(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
@ -433,7 +433,7 @@ func TestGRPCLBWeighted(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
@ -498,7 +498,7 @@ func TestDropRequest(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
@ -585,7 +585,7 @@ func TestBalancerDisconnects(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
@ -666,7 +666,7 @@ func TestFallback(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }
@ -760,7 +760,7 @@ func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rp
cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
grpc.WithTransportCredentials(&creds), grpc.WithTransportCredentials(&creds),
grpc.WithPerRPCCredentials(failPreRPCCred{}), grpc.WithPerRPCCredentials(failPreRPCCred{}),
grpc.WithDialer(fakeNameDialer)) grpc.WithContextDialer(fakeNameDialer))
if err != nil { if err != nil {
t.Fatalf("Failed to dial to the backend %v", err) t.Fatalf("Failed to dial to the backend %v", err)
} }

View File

@ -194,11 +194,10 @@ func makeClient(benchFeatures stats.Features) (testpb.BenchmarkServiceClient, fu
if *useBufconn { if *useBufconn {
bcLis := bufconn.Listen(256 * 1024) bcLis := bufconn.Listen(256 * 1024)
lis = bcLis lis = bcLis
opts = append(opts, grpc.WithDialer(func(string, time.Duration) (net.Conn, error) { opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) {
return nw.TimeoutDialer( return nw.ContextDialer(func(context.Context, string, string) (net.Conn, error) {
func(string, string, time.Duration) (net.Conn, error) { return bcLis.Dial()
return bcLis.Dial() })(ctx, "", "")
})("", "", 0)
})) }))
} else { } else {
var err error var err error
@ -206,8 +205,8 @@ func makeClient(benchFeatures stats.Features) (testpb.BenchmarkServiceClient, fu
if err != nil { if err != nil {
grpclog.Fatalf("Failed to listen: %v", err) grpclog.Fatalf("Failed to listen: %v", err)
} }
opts = append(opts, grpc.WithDialer(func(_ string, timeout time.Duration) (net.Conn, error) { opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", lis.Addr().String(), timeout) return nw.ContextDialer((&net.Dialer{}).DialContext)(ctx, "tcp", lis.Addr().String())
})) }))
} }
lis = nw.Listener(lis) lis = nw.Listener(lis)

View File

@ -320,8 +320,8 @@ func runUnary(b *testing.B, benchFeatures stats.Features) {
defer stopper() defer stopper()
conn := NewClientConn( conn := NewClientConn(
target, grpc.WithInsecure(), target, grpc.WithInsecure(),
grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", address, timeout) return nw.ContextDialer((&net.Dialer{}).DialContext)(ctx, "tcp", address)
}), }),
) )
tc := testpb.NewBenchmarkServiceClient(conn) tc := testpb.NewBenchmarkServiceClient(conn)
@ -374,8 +374,8 @@ func runStream(b *testing.B, benchFeatures stats.Features) {
defer stopper() defer stopper()
conn := NewClientConn( conn := NewClientConn(
target, grpc.WithInsecure(), target, grpc.WithInsecure(),
grpc.WithDialer(func(address string, timeout time.Duration) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) {
return nw.TimeoutDialer(net.DialTimeout)("tcp", address, timeout) return nw.ContextDialer((&net.Dialer{}).DialContext)(ctx, "tcp", address)
}), }),
) )
tc := testpb.NewBenchmarkServiceClient(conn) tc := testpb.NewBenchmarkServiceClient(conn)

View File

@ -329,14 +329,17 @@ func WithTimeout(d time.Duration) DialOption {
}) })
} }
func withContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption { // WithContextDialer returns a DialOption that sets a dialer to create
// connections. If FailOnNonTempDialError() is set to true, and an error is
// returned by f, gRPC checks the error's Temporary() method to decide if it
// should try to reconnect to the network address.
func WithContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.copts.Dialer = f o.copts.Dialer = f
}) })
} }
func init() { func init() {
internal.WithContextDialer = withContextDialer
internal.WithResolverBuilder = withResolverBuilder internal.WithResolverBuilder = withResolverBuilder
internal.WithHealthCheckFunc = withHealthCheckFunc internal.WithHealthCheckFunc = withHealthCheckFunc
} }
@ -345,8 +348,10 @@ func init() {
// network addresses. If FailOnNonTempDialError() is set to true, and an error // network addresses. If FailOnNonTempDialError() is set to true, and an error
// is returned by f, gRPC checks the error's Temporary() method to decide if it // is returned by f, gRPC checks the error's Temporary() method to decide if it
// should try to reconnect to the network address. // should try to reconnect to the network address.
//
// Deprecated: use WithContextDialer instead
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return withContextDialer( return WithContextDialer(
func(ctx context.Context, addr string) (net.Conn, error) { func(ctx context.Context, addr string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
return f(addr, time.Until(deadline)) return f(addr, time.Until(deadline))

View File

@ -26,8 +26,6 @@ import (
) )
var ( var (
// WithContextDialer is exported by dialoptions.go
WithContextDialer interface{} // func(context.Context, string) (net.Conn, error) grpc.DialOption
// WithResolverBuilder is exported by dialoptions.go // WithResolverBuilder is exported by dialoptions.go
WithResolverBuilder interface{} // func (resolver.Builder) grpc.DialOption WithResolverBuilder interface{} // func (resolver.Builder) grpc.DialOption
// WithHealthCheckFunc is not exported by dialoptions.go // WithHealthCheckFunc is not exported by dialoptions.go

View File

@ -76,7 +76,7 @@ func (d *delayListener) allowClientRead() {
d.cc.allowRead() d.cc.allowRead()
} }
func (d *delayListener) Dial(to time.Duration) (net.Conn, error) { func (d *delayListener) Dial(ctx context.Context) (net.Conn, error) {
if d.dialed { if d.dialed {
// Only hand out one connection (net.Dial can return more even after the // Only hand out one connection (net.Dial can return more even after the
// listener is closed). This is not thread-safe, but Dial should never be // listener is closed). This is not thread-safe, but Dial should never be
@ -84,7 +84,7 @@ func (d *delayListener) Dial(to time.Duration) (net.Conn, error) {
return nil, fmt.Errorf("no more conns") return nil, fmt.Errorf("no more conns")
} }
d.dialed = true d.dialed = true
c, err := net.DialTimeout("tcp", d.Listener.Addr().String(), to) c, err := (&net.Dialer{}).DialContext(ctx, "tcp", d.Listener.Addr().String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -137,7 +137,7 @@ func (s) TestGracefulStop(t *testing.T) {
closeCalled: make(chan struct{}), closeCalled: make(chan struct{}),
allowCloseCh: make(chan struct{}), allowCloseCh: make(chan struct{}),
} }
d := func(_ string, to time.Duration) (net.Conn, error) { return dlis.Dial(to) } d := func(ctx context.Context, _ string) (net.Conn, error) { return dlis.Dial(ctx) }
serverGotReq := make(chan struct{}) serverGotReq := make(chan struct{})
ss := &stubServer{ ss := &stubServer{
@ -180,7 +180,7 @@ func (s) TestGracefulStop(t *testing.T) {
// even though GracefulStop has closed the listener. // even though GracefulStop has closed the listener.
ctx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, dialCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer dialCancel() defer dialCancel()
cc, err := grpc.DialContext(ctx, "", grpc.WithInsecure(), grpc.WithBlock(), grpc.WithDialer(d)) cc, err := grpc.DialContext(ctx, "", grpc.WithInsecure(), grpc.WithBlock(), grpc.WithContextDialer(d))
if err != nil { if err != nil {
dlis.allowClientRead() dlis.allowClientRead()
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)