This commit is contained in:
iamqizhao 2016-08-26 13:51:46 -07:00
commit 61f62e0da6
4 changed files with 67 additions and 9 deletions

View File

@ -250,13 +250,13 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
defer func() {
select {
case <-ctx.Done():
if conn != nil {
conn.Close()
}
conn = nil
err = ctx.Err()
conn, err = nil, ctx.Err()
default:
}
if err != nil {
cc.Close()
}
}()
for _, opt := range opts {
@ -312,11 +312,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
return nil, ctx.Err()
case err := <-waitC:
if err != nil {
cc.Close()
return nil, err
}
case <-timeoutCh:
cc.Close()
return nil, ErrClientConnTimeout
}
// If balancer is nil or balancer.Notify() is nil, ok will be false here.

View File

@ -40,6 +40,7 @@ package credentials // import "google.golang.org/grpc/credentials"
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
@ -86,6 +87,12 @@ type AuthInfo interface {
AuthType() string
}
var (
// ErrConnDispatched indicates that rawConn has been dispatched out of gRPC
// and the caller should not close rawConn.
ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC")
)
// TransportCredentials defines the common interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL).
type TransportCredentials interface {

View File

@ -367,7 +367,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
s.mu.Unlock()
grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
rawConn.Close()
// If serverHandShake returns ErrConnDispatched, keep rawConn open.
if err != credentials.ErrConnDispatched {
rawConn.Close()
}
return
}

View File

@ -848,9 +848,11 @@ func testFailFast(t *testing.T, e env) {
te.srv.Stop()
// Loop until the server teardown is propagated to the client.
for {
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.Unavailable {
_, err := tc.EmptyCall(context.Background(), &testpb.Empty{})
if grpc.Code(err) == codes.Unavailable {
break
}
fmt.Printf("%v.EmptyCall(_, _) = _, %v", tc, err)
time.Sleep(10 * time.Millisecond)
}
// The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Unavailable.
@ -2462,6 +2464,54 @@ func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
}
}
type serverDispatchCred struct {
ready chan struct{}
rawConn net.Conn
}
func newServerDispatchCred() *serverDispatchCred {
return &serverDispatchCred{
ready: make(chan struct{}),
}
}
func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, nil, nil
}
func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
c.rawConn = rawConn
close(c.ready)
return nil, nil, credentials.ErrConnDispatched
}
func (c *serverDispatchCred) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (c *serverDispatchCred) getRawConn() net.Conn {
<-c.ready
return c.rawConn
}
func TestServerCredsDispatch(t *testing.T) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
cred := newServerDispatchCred()
s := grpc.NewServer(grpc.Creds(cred))
go s.Serve(lis)
defer s.Stop()
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred))
if err != nil {
t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
}
defer cc.Close()
// Check rawConn is not closed.
if n, err := cred.getRawConn().Write([]byte{0}); n <= 0 || err != nil {
t.Errorf("Read() = %v, %v; want n>0, <nil>", n, err)
}
}
// interestingGoroutines returns all goroutines we care about for the purpose
// of leak checking. It excludes testing or runtime ones.
func interestingGoroutines() (gs []string) {