diff --git a/call.go b/call.go index a516017ce..e15e4f9d5 100644 --- a/call.go +++ b/call.go @@ -35,7 +35,6 @@ package grpc import ( "io" - "net" "golang.org/x/net/context" "google.golang.org/grpc/codes" @@ -114,12 +113,8 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli o.after(&c) } }() - host, _, err := net.SplitHostPort(cc.target) - if err != nil { - return toRPCErr(err) - } callHdr := &transport.CallHdr{ - Host: host, + Host: cc.authority, Method: method, } topts := &transport.Options{ diff --git a/clientconn.go b/clientconn.go index 1c6f8a146..00a3b8d87 100644 --- a/clientconn.go +++ b/clientconn.go @@ -36,6 +36,7 @@ package grpc import ( "errors" "log" + "net" "sync" "time" @@ -95,6 +96,14 @@ func WithTimeout(d time.Duration) DialOption { } } +// WithNetwork returns a DialOption that specifies the network on which +// the connection will be established. +func WithNetwork(network string) DialOption { + return func(o *dialOptions) { + o.copts.Network = network + } +} + // Dial creates a client connection the given target. // TODO(zhaoq): Have an option to make Dial return immediately without waiting // for connection to complete. @@ -108,6 +117,23 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { for _, opt := range opts { opt(&cc.dopts) } + // Validate the network type + switch cc.dopts.copts.Network { + case "": + cc.dopts.copts.Network = "tcp" // Set the default + case "tcp", "tcp4", "tcp6", "unix": + default: + return nil, net.UnknownNetworkError(cc.dopts.copts.Network) + } + cc.authority = target + // Format target for tcp. + if cc.dopts.copts.Network != "unix" { + var err error + cc.authority, _, err = net.SplitHostPort(target) + if err != nil { + return nil, err + } + } if cc.dopts.codec == nil { // Set the default codec. 
cc.dopts.codec = protoCodec{} @@ -124,6 +151,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { // ClientConn represents a client connection to an RPC service. type ClientConn struct { target string + authority string dopts dialOptions shutdownChan chan struct{} diff --git a/credentials/credentials.go b/credentials/credentials.go index c45370787..fae0a302b 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -105,7 +105,7 @@ func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ n return nil, fmt.Errorf("credentials: failed to parse server address %v", err) } } - return tls.DialWithDialer(dialer, "tcp", addr, &c.config) + return tls.DialWithDialer(dialer, network, addr, &c.config) } // Dial connects to addr and performs TLS handshake. diff --git a/server.go b/server.go index 1cdf1555b..bf0ad3b34 100644 --- a/server.go +++ b/server.go @@ -371,8 +371,8 @@ func (s *Server) TestingCloseConns() { s.mu.Lock() for c := range s.conns { c.Close() - delete(s.conns, c) } + s.conns = make(map[transport.ServerTransport]bool) s.mu.Unlock() } diff --git a/stream.go b/stream.go index 2aa160e28..43fdcbecb 100644 --- a/stream.go +++ b/stream.go @@ -36,7 +36,6 @@ package grpc import ( "errors" "io" - "net" "golang.org/x/net/context" "google.golang.org/grpc/codes" @@ -95,12 +94,8 @@ type ClientStream interface { // by generated code. func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { // TODO(zhaoq): CallOption is omitted. Add support when it is needed. 
- host, _, err := net.SplitHostPort(cc.target) - if err != nil { - return nil, toRPCErr(err) - } callHdr := &transport.CallHdr{ - Host: host, + Host: cc.authority, Method: method, } t, _, err := cc.wait(ctx, 0) diff --git a/test/end2end_test.go b/test/end2end_test.go index cfdd68802..058617de1 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -34,12 +34,15 @@ package grpc_test import ( + "fmt" "io" "log" "math" "net" "reflect" + "runtime" "sync" + "syscall" "testing" "time" @@ -263,18 +266,32 @@ func TestReconnectTimeout(t *testing.T) { } } -func setUp(useTLS bool, maxStream uint32) (s *grpc.Server, cc *grpc.ClientConn) { - lis, err := net.Listen("tcp", ":0") +type env struct { + network string // The type of network such as tcp, unix, etc. + security string // The security protocol such as TLS, SSH, etc. +} + +func listTestEnv() []env { + if runtime.GOOS == "windows" { + return []env{env{"tcp", ""}, env{"tcp", "tls"}} + } + return []env{env{"tcp", ""}, env{"tcp", "tls"}, env{"unix", ""}, env{"unix", "tls"}} +} + +func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) { + s = grpc.NewServer(grpc.MaxConcurrentStreams(maxStream)) + la := ":0" + switch e.network { + case "unix": + la = "/tmp/testsock" + fmt.Sprintf("%p", s) + syscall.Unlink(la) + } + lis, err := net.Listen(e.network, la) if err != nil { log.Fatalf("Failed to listen: %v", err) } - _, port, err := net.SplitHostPort(lis.Addr().String()) - if err != nil { - log.Fatalf("Failed to parse listener address: %v", err) - } - s = grpc.NewServer(grpc.MaxConcurrentStreams(maxStream)) testpb.RegisterTestServiceServer(s, &testServer{}) - if useTLS { + if e.security == "tls" { creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") if err != nil { log.Fatalf("Failed to generate credentials %v", err) @@ -283,15 +300,24 @@ func setUp(useTLS bool, maxStream uint32) (s *grpc.Server, cc *grpc.ClientConn) } else { go s.Serve(lis) } - addr := "localhost:" + 
port - if useTLS { + addr := la + switch e.network { + case "unix": + default: + _, port, err := net.SplitHostPort(lis.Addr().String()) + if err != nil { + log.Fatalf("Failed to parse listener address: %v", err) + } + addr = "localhost:" + port + } + if e.security == "tls" { creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil { log.Fatalf("Failed to create credentials %v", err) } - cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds)) + cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithNetwork(e.network)) } else { - cc, err = grpc.Dial(addr) + cc, err = grpc.Dial(addr, grpc.WithNetwork(e.network)) } if err != nil { log.Fatalf("Dial(%q) = %v", addr, err) @@ -305,7 +331,14 @@ func tearDown(s *grpc.Server, cc *grpc.ClientConn) { } func TestTimeoutOnDeadServer(t *testing.T) { - s, cc := setUp(false, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testTimeoutOnDeadServer(t, e) + } +} + +func testTimeoutOnDeadServer(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) s.Stop() // Set -1 as the timeout to make sure if transportMonitor gets error @@ -319,7 +352,14 @@ func TestTimeoutOnDeadServer(t *testing.T) { } func TestEmptyUnary(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testEmptyUnary(t, e) + } +} + +func testEmptyUnary(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}) @@ -329,7 +369,14 @@ func TestEmptyUnary(t *testing.T) { } func TestFailedEmptyUnary(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testFailedEmptyUnary(t, e) + } +} + +func testFailedEmptyUnary(t *testing.T, e env) { + s, cc 
:= setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) ctx := metadata.NewContext(context.Background(), testMetadata) @@ -339,7 +386,14 @@ func TestFailedEmptyUnary(t *testing.T) { } func TestLargeUnary(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testLargeUnary(t, e) + } +} + +func testLargeUnary(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) argSize := 271828 @@ -361,7 +415,14 @@ func TestLargeUnary(t *testing.T) { } func TestMetadataUnaryRPC(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testMetadataUnaryRPC(t, e) + } +} + +func testMetadataUnaryRPC(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) argSize := 2718 @@ -405,11 +466,18 @@ func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup wg.Done() } +func TestRetry(t *testing.T) { + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testRetry(t, e) + } +} + // This test mimics a user who sends 1000 RPCs concurrently on a faulty transport. // TODO(zhaoq): Refactor to make this clearer and add more cases to test racy // and error-prone paths. -func TestRetry(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) +func testRetry(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) var wg sync.WaitGroup @@ -431,9 +499,16 @@ func TestRetry(t *testing.T) { wg.Wait() } -// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism. 
func TestRPCTimeout(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testRPCTimeout(t, e) + } +} + +// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism. +func testRPCTimeout(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) argSize := 2718 @@ -456,7 +531,14 @@ func TestRPCTimeout(t *testing.T) { } func TestCancel(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testCancel(t, e) + } +} + +func testCancel(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) argSize := 2718 @@ -482,7 +564,14 @@ var ( ) func TestPingPong(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testPingPong(t, e) + } +} + +func testPingPong(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) stream, err := tc.FullDuplexCall(context.Background()) @@ -527,7 +616,14 @@ func TestPingPong(t *testing.T) { } func TestMetadataStreamingRPC(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testMetadataStreamingRPC(t, e) + } +} + +func testMetadataStreamingRPC(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) ctx := metadata.NewContext(context.Background(), testMetadata) @@ -578,7 +674,14 @@ func TestMetadataStreamingRPC(t *testing.T) { } func TestServerStreaming(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testServerStreaming(t, e) + } +} + +func testServerStreaming(t *testing.T, e env) { 
+ s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) respParam := make([]*testpb.ResponseParameters, len(respSizes)) @@ -624,7 +727,14 @@ func TestServerStreaming(t *testing.T) { } func TestFailedServerStreaming(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testFailedServerStreaming(t, e) + } +} + +func testFailedServerStreaming(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) respParam := make([]*testpb.ResponseParameters, len(respSizes)) @@ -648,7 +758,14 @@ func TestFailedServerStreaming(t *testing.T) { } func TestClientStreaming(t *testing.T) { - s, cc := setUp(true, math.MaxUint32) + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testClientStreaming(t, e) + } +} + +func testClientStreaming(t *testing.T, e env) { + s, cc := setUp(math.MaxUint32, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) stream, err := tc.StreamingInputCall(context.Background()) @@ -676,8 +793,15 @@ func TestClientStreaming(t *testing.T) { } func TestExceedMaxStreamsLimit(t *testing.T) { + for _, e := range listTestEnv() { + log.Println("Testing in the env: ", e) + testExceedMaxStreamsLimit(t, e) + } +} + +func testExceedMaxStreamsLimit(t *testing.T, e env) { // Only allows 1 live stream per server transport. - s, cc := setUp(true, 1) + s, cc := setUp(1, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) var err error diff --git a/transport/http2_client.go b/transport/http2_client.go index baeee8f6a..136debfbc 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -110,12 +110,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e // multiple ones provided. Revisit this if it is not appropriate. 
Probably // place the ClientTransport construction into a separate function to make // things clear. - conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, "tcp", addr) + conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, addr) break } } if scheme == "http" { - conn, connErr = net.DialTimeout("tcp", addr, opts.Timeout) + conn, connErr = net.DialTimeout(opts.Network, addr, opts.Timeout) } if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) diff --git a/transport/transport.go b/transport/transport.go index 0824f2795..ebd36291c 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -315,7 +315,9 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv // ConnectOptions covers all relevant options for dialing a server. type ConnectOptions struct { - Protocol string + // Network indicates the type of network where the connection is established. + // Known networks are "tcp", "tcp4", "tcp6", "unix" + Network string AuthOptions []credentials.Credentials Timeout time.Duration }