From 2342e3866997f91c848f15bc4afe6827f70dcaab Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 27 Jul 2016 16:52:39 -0400 Subject: [PATCH 1/2] test,transport: simplify --- clientconn.go | 7 ++++--- test/end2end_test.go | 24 ++++++------------------ transport/http2_client.go | 17 +++++++++-------- transport/transport.go | 2 +- transport/transport_test.go | 2 +- 5 files changed, 21 insertions(+), 31 deletions(-) diff --git a/clientconn.go b/clientconn.go index 3206d6747..e43a8f721 100644 --- a/clientconn.go +++ b/clientconn.go @@ -558,12 +558,13 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { t.Close() } sleepTime := ac.dopts.bs.backoff(retries) - ac.dopts.copts.Timeout = sleepTime + copts := ac.dopts.copts + copts.Timeout = sleepTime if sleepTime < minConnectTimeout { - ac.dopts.copts.Timeout = minConnectTimeout + copts.Timeout = minConnectTimeout } connectTime := time.Now() - newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) + newTransport, err := transport.NewClientTransport(ac.addr.Addr, copts) if err != nil { ac.mu.Lock() if ac.state == Shutdown { diff --git a/test/end2end_test.go b/test/end2end_test.go index cdbc4c555..769ad0646 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -300,39 +300,29 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ const tlsDir = "testdata/" -func unixDialer(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("unix", addr, timeout) -} - type env struct { name string network string // The type of network such as tcp, unix, etc. - dialer func(addr string, timeout time.Duration) (net.Conn, error) security string // The security protocol such as TLS, SSH, etc. httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS } func (e env) runnable() bool { - if runtime.GOOS == "windows" && strings.HasPrefix(e.name, "unix-") { + if runtime.GOOS == "windows" && e.network == "unix" { return false } return true } -func (e env) getDialer() func(addr string, timeout time.Duration) (net.Conn, error) { - if e.dialer != nil { - return e.dialer - } - return func(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) - } +func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout(e.network, addr, timeout) } var ( tcpClearEnv = env{name: "tcp-clear", network: "tcp"} tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} - unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer} - unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"} + unixClearEnv = env{name: "unix-clear", network: "unix"} + unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls"} handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true} allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv} ) @@ -515,9 +505,7 @@ func (te *test) declareLogNoise(phrases ...string) { } func (te *test) withServerTester(fn func(st *serverTester)) { - var c net.Conn - var err error - c, err = te.e.getDialer()(te.srvAddr, 10*time.Second) + c, err := te.e.dialer(te.srvAddr, 10*time.Second) if err != nil { te.t.Fatal(err) } diff --git a/transport/http2_client.go b/transport/http2_client.go index 51cf17920..5fd5895fd 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -107,20 +107,21 @@ type http2Client struct { prevGoAwayID uint32 } +func dial(fn func(string, time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) { + if fn != nil { + return fn(addr, timeout) + } + return net.DialTimeout("tcp", addr, timeout) +} + // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { - if opts.Dialer == nil { - // Set the default Dialer. - opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) - } - } +func newHTTP2Client(addr string, opts ConnectOptions) (_ ClientTransport, err error) { scheme := "http" startT := time.Now() timeout := opts.Timeout - conn, connErr := opts.Dialer(addr, timeout) + conn, connErr := dial(opts.Dialer, addr, timeout) if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) } diff --git a/transport/transport.go b/transport/transport.go index c41436e1a..31a90bf78 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -366,7 +366,7 @@ type ConnectOptions struct { // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { +func NewClientTransport(target string, opts ConnectOptions) (ClientTransport, error) { return newHTTP2Client(target, opts) } diff --git a/transport/transport_test.go b/transport/transport_test.go index 6dd01d851..ecfc2ca4a 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -221,7 +221,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client ct ClientTransport connErr error ) - ct, connErr = NewClientTransport(addr, &ConnectOptions{}) + ct, connErr = NewClientTransport(addr, ConnectOptions{}) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) } From 61f3f61ef030590dd1a6cfc861cdc75ba5c9c542 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Tue, 5 Jul 2016 14:56:57 -0400 Subject: [PATCH 2/2] cancel outgoing net.Dial when ClientConn is closed --- clientconn.go | 29 ++++++++++++------------- test/end2end_test.go | 10 +++++++-- transport/go16.go | 45 +++++++++++++++++++++++++++++++++++++++ transport/http2_client.go | 8 +++---- transport/pre_go16.go | 45 +++++++++++++++++++++++++++++++++++++++ transport/transport.go | 4 +++- 6 files changed, 119 insertions(+), 22 deletions(-) create mode 100644 transport/go16.go create mode 100644 transport/pre_go16.go diff --git a/clientconn.go b/clientconn.go index e43a8f721..214fb9005 100644 --- a/clientconn.go +++ b/clientconn.go @@ -196,7 +196,7 @@ func WithTimeout(d time.Duration) DialOption { } // WithDialer returns a DialOption that specifies a function to use for dialing network addresses. -func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { +func WithDialer(f func(string, time.Duration, <-chan struct{}) (net.Conn, error)) DialOption { return func(o *dialOptions) { o.copts.Dialer = f } @@ -361,11 +361,11 @@ func (cc *ClientConn) lbWatcher() { func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { ac := &addrConn{ - cc: cc, - addr: addr, - dopts: cc.dopts, - shutdownChan: make(chan struct{}), + cc: cc, + addr: addr, + dopts: cc.dopts, } + ac.dopts.copts.Cancel = make(chan struct{}) if EnableTracing { ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) } @@ -468,11 +468,10 @@ func (cc *ClientConn) Close() error { // addrConn is a network connection to a given address. type addrConn struct { - cc *ClientConn - addr Address - dopts dialOptions - shutdownChan chan struct{} - events trace.EventLog + cc *ClientConn + addr Address + dopts dialOptions + events trace.EventLog mu sync.Mutex state ConnectivityState @@ -587,7 +586,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { closeTransport = false select { case <-time.After(sleepTime): - case <-ac.shutdownChan: + case <-ac.dopts.copts.Cancel: } retries++ grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) @@ -622,9 +621,9 @@ func (ac *addrConn) transportMonitor() { t := ac.transport ac.mu.Unlock() select { - // shutdownChan is needed to detect the teardown when + // Cancel is needed to detect the teardown when // the addrConn is idle (i.e., no RPC in flight). - case <-ac.shutdownChan: + case <-ac.dopts.copts.Cancel: return case <-t.GoAway(): ac.tearDown(errConnDrain) @@ -725,8 +724,8 @@ func (ac *addrConn) tearDown(err error) { if ac.transport != nil && err != errConnDrain { ac.transport.Close() } - if ac.shutdownChan != nil { - close(ac.shutdownChan) + if ac.dopts.copts.Cancel != nil { + close(ac.dopts.copts.Cancel) } return } diff --git a/test/end2end_test.go b/test/end2end_test.go index 769ad0646..81bc15590 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -314,7 +314,13 @@ func (e env) runnable() bool { return true } -func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { +func (e env) dialer(addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) { + // NB: Go 1.6 added a Cancel field on net.Dialer, which would allow this + // to be written as + // + // `(&net.Dialer{Cancel: cancel, Timeout: timeout}).Dial(e.network, addr)` + // + // but that would break compatibility with earlier Go versions. return net.DialTimeout(e.network, addr, timeout) } @@ -505,7 +511,7 @@ func (te *test) declareLogNoise(phrases ...string) { } func (te *test) withServerTester(fn func(st *serverTester)) { - c, err := te.e.dialer(te.srvAddr, 10*time.Second) + c, err := te.e.dialer(te.srvAddr, 10*time.Second, nil) if err != nil { te.t.Fatal(err) } diff --git a/transport/go16.go b/transport/go16.go new file mode 100644 index 000000000..c0d051ef9 --- /dev/null +++ b/transport/go16.go @@ -0,0 +1,45 @@ +// +build go1.6 + +/* + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "net" + "time" +) + +// newDialer constructs a net.Dialer. +func newDialer(timeout time.Duration, cancel <-chan struct{}) *net.Dialer { + return &net.Dialer{Cancel: cancel, Timeout: timeout} +} diff --git a/transport/http2_client.go b/transport/http2_client.go index 5fd5895fd..2f5d5a809 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -107,11 +107,11 @@ type http2Client struct { prevGoAwayID uint32 } -func dial(fn func(string, time.Duration) (net.Conn, error), addr string, timeout time.Duration) (net.Conn, error) { +func dial(fn func(string, time.Duration, <-chan struct{}) (net.Conn, error), addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) { if fn != nil { - return fn(addr, timeout) + return fn(addr, timeout, cancel) } - return net.DialTimeout("tcp", addr, timeout) + return newDialer(timeout, cancel).Dial("tcp", addr) } // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 @@ -121,7 +121,7 @@ func newHTTP2Client(addr string, opts ConnectOptions) (_ ClientTransport, err er scheme := "http" startT := time.Now() timeout := opts.Timeout - conn, connErr := dial(opts.Dialer, addr, timeout) + conn, connErr := dial(opts.Dialer, addr, timeout, opts.Cancel) if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) } diff --git a/transport/pre_go16.go b/transport/pre_go16.go new file mode 100644 index 000000000..126bfbd80 --- /dev/null +++ b/transport/pre_go16.go @@ -0,0 +1,45 @@ +// +build !go1.6 + +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "net" + "time" +) + +// newDialer constructs a net.Dialer. +func newDialer(timeout time.Duration, _ <-chan struct{}) *net.Dialer { + return &net.Dialer{Timeout: timeout} +} diff --git a/transport/transport.go b/transport/transport.go index 31a90bf78..104c1c1f1 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -354,8 +354,10 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authI type ConnectOptions struct { // UserAgent is the application user agent. UserAgent string + // Cancel is closed to indicate that dialing should be cancelled. + Cancel chan struct{} // Dialer specifies how to dial a network address. - Dialer func(string, time.Duration) (net.Conn, error) + Dialer func(string, time.Duration, <-chan struct{}) (net.Conn, error) // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. PerRPCCredentials []credentials.PerRPCCredentials // TransportCredentials stores the Authenticator required to setup a client connection.