From 3617cd5ab3aa13669cf4c090de93d85ebf6e6d5f Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Tue, 12 May 2015 17:59:20 -0700 Subject: [PATCH] revert handshaker changes --- clientconn.go | 10 +++--- credentials/credentials.go | 22 +++++++++----- examples/route_guide/server/server.go | 10 +++--- grpc-auth-support.md | 2 ++ interop/server/server.go | 10 +++--- server.go | 32 ++++++++++--------- test/end2end_test.go | 12 ++++---- transport/http2_client.go | 13 +------- transport/transport.go | 1 - transport/transport_test.go | 44 ++++++++------------------- 10 files changed, 67 insertions(+), 89 deletions(-) diff --git a/clientconn.go b/clientconn.go index 99c2d0c08..4f7fb36ba 100644 --- a/clientconn.go +++ b/clientconn.go @@ -107,11 +107,11 @@ func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) Di // WithHandshaker returns a DialOption that specifies a function to perform some handshaking // with the server. It is typically used to negotiate the wire protocol version and security // protocol with the server. -func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption { - return func(o *dialOptions) { - o.copts.Handshaker = h - } -} +//func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption { +// return func(o *dialOptions) { +// o.copts.Handshaker = h +// } +//} // Dial creates a client connection the given target. // TODO(zhaoq): Have an option to make Dial return immediately without waiting diff --git a/credentials/credentials.go b/credentials/credentials.go index dd43cc082..fd78a0ffa 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -84,12 +84,14 @@ type ProtocolInfo struct { // TransportAuthenticator defines the common interface for all the live gRPC wire // protocols and supported transport security protocols (e.g., TLS, SSL). type TransportAuthenticator interface { - // Handshake does the authentication handshake specified by the corresponding - // authentication protocol on rawConn. - Handshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error) + // ClientHandshake does the authentication handshake specified by the corresponding + // authentication protocol on rawConn for clients. + ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error) + // ServerHandshake does the authentication handshake for servers. + ServerHandshake(rawConn net.Conn) (net.Conn, error) // NewListener creates a listener which accepts connections with requested // authentication handshake. - NewListener(lis net.Listener) net.Listener + //NewListener(lis net.Listener) net.Listener // Info provides the ProtocolInfo of this TransportAuthenticator. Info() ProtocolInfo Credentials @@ -120,7 +122,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" } func (timeoutError) Timeout() bool { return true } func (timeoutError) Temporary() bool { return true } -func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) { +func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) { // borrow some code from tls.DialWithDialer var errChannel chan error if timeout != 0 { @@ -152,9 +154,13 @@ func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duratio return conn, nil } -// NewListener creates a net.Listener using the information in tlsCreds. -func (c *tlsCreds) NewListener(lis net.Listener) net.Listener { - return tls.NewListener(lis, &c.config) +func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, error) { + conn := tls.Server(rawConn, &c.config) + if err := conn.Handshake(); err != nil { + rawConn.Close() + return nil, err + } + return conn, nil } // NewTLS uses c to construct a TransportAuthenticator based on TLS. diff --git a/examples/route_guide/server/server.go b/examples/route_guide/server/server.go index cc47fcb52..c33234e2b 100644 --- a/examples/route_guide/server/server.go +++ b/examples/route_guide/server/server.go @@ -225,15 +225,15 @@ func main() { if err != nil { grpclog.Fatalf("failed to listen: %v", err) } - grpcServer := grpc.NewServer() - pb.RegisterRouteGuideServer(grpcServer, newServer()) + var opts []grpc.ServerOption if *tls { creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) if err != nil { grpclog.Fatalf("Failed to generate credentials %v", err) } - grpcServer.Serve(creds.NewListener(lis)) - } else { - grpcServer.Serve(lis) + opts = []grpc.ServerOption{grpc.Creds(creds)} } + grpcServer := grpc.NewServer(opts...) + pb.RegisterRouteGuideServer(grpcServer, newServer()) + grpcServer.Serve(lis) } diff --git a/grpc-auth-support.md b/grpc-auth-support.md index 80565ca78..b5ae1d087 100644 --- a/grpc-auth-support.md +++ b/grpc-auth-support.md @@ -15,6 +15,8 @@ creds, err := credentials.NewServerTLSFromFile(certFile, keyFile) if err != nil { log.Fatalf("Failed to generate credentials %v", err) } +server := grpc.NewServer(grpc.Creds(creds)) +... server.Serve(creds.NewListener(lis)) ``` diff --git a/interop/server/server.go b/interop/server/server.go index 4c79cb18a..f781c9d53 100644 --- a/interop/server/server.go +++ b/interop/server/server.go @@ -195,15 +195,15 @@ func main() { if err != nil { grpclog.Fatalf("failed to listen: %v", err) } - server := grpc.NewServer() - testpb.RegisterTestServiceServer(server, &testServer{}) + var opts []grpc.ServerOption if *useTLS { creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) if err != nil { grpclog.Fatalf("Failed to generate credentials %v", err) } - server.Serve(creds.NewListener(lis)) - } else { - server.Serve(lis) + opts = []grpc.ServerOption{grpc.Creds(creds)} } + server := grpc.NewServer(opts...) + testpb.RegisterTestServiceServer(server, &testServer{}) + server.Serve(lis) } diff --git a/server.go b/server.go index a8be76393..5a3ee8eb7 100644 --- a/server.go +++ b/server.go @@ -44,6 +44,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" "google.golang.org/grpc/transport" @@ -85,7 +86,7 @@ type Server struct { } type options struct { - handshaker func(net.Conn) error + creds []credentials.Credentials codec Codec maxConcurrentStreams uint32 } @@ -93,14 +94,6 @@ type options struct { // A ServerOption sets options. type ServerOption func(*options) -// Handshaker returns a ServerOption that specifies a function to perform user-specified -// handshaking on the connection before it becomes usable for gRPC. -func Handshaker(f func(net.Conn) error) ServerOption { - return func(o *options) { - o.handshaker = f - } -} - // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. func CustomCodec(codec Codec) ServerOption { return func(o *options) { @@ -116,6 +109,13 @@ func MaxConcurrentStreams(n uint32) ServerOption { } } +// Creds returns a ServerOption that sets credentials for server connections. +func Creds(c credentials.Credentials) ServerOption { + return func(o *options) { + o.creds = append(o.creds, c) + } +} + // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -195,12 +195,14 @@ func (s *Server) Serve(lis net.Listener) error { if err != nil { return err } - // Perform handshaking if it is required. - if s.opts.handshaker != nil { - if err := s.opts.handshaker(c); err != nil { - grpclog.Println("grpc: Server.Serve failed to complete handshake.") - c.Close() - continue + for _, o := range s.opts.creds { + if creds, ok := o.(credentials.TransportAuthenticator); ok { + c, err = creds.ServerHandshake(c) + if err != nil { + grpclog.Println("grpc: Server.Serve failed to complete security handshake.") + continue + } + break } } s.mu.Lock() diff --git a/test/end2end_test.go b/test/end2end_test.go index ded266c7b..3f55024b3 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -284,27 +284,27 @@ func listTestEnv() []env { } func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) { - s = grpc.NewServer(grpc.MaxConcurrentStreams(maxStream)) + sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)} la := ":0" switch e.network { case "unix": - la = "/tmp/testsock" + fmt.Sprintf("%p", s) + la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now()) syscall.Unlink(la) } lis, err := net.Listen(e.network, la) if err != nil { grpclog.Fatalf("Failed to listen: %v", err) } - testpb.RegisterTestServiceServer(s, &testServer{}) if e.security == "tls" { creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") if err != nil { grpclog.Fatalf("Failed to generate credentials %v", err) } - go s.Serve(creds.NewListener(lis)) - } else { - go s.Serve(lis) + sopts = append(sopts, grpc.Creds(creds)) } + s = grpc.NewServer(sopts...) + testpb.RegisterTestServiceServer(s, &testServer{}) + go s.Serve(lis) addr := la switch e.network { case "unix": diff --git a/transport/http2_client.go b/transport/http2_client.go index 0a7dd70e6..6ba934489 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -111,17 +111,6 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) } - // Perform handshake if opts.Handshaker is set. - if opts.Handshaker != nil { - auth, err := opts.Handshaker(conn) - if err != nil { - return nil, ConnectionErrorf("transport: handshaking failed %v", err) - } - // Prepend the resulting authenticator to opts.AuthOptions. - if auth != nil { - opts.AuthOptions = append([]credentials.Credentials{auth}, opts.AuthOptions...) - } - } for _, c := range opts.AuthOptions { if ccreds, ok := c.(credentials.TransportAuthenticator); ok { scheme = "https" @@ -132,7 +121,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e if timeout > 0 { timeout -= time.Since(startT) } - conn, connErr = ccreds.Handshake(addr, conn, timeout) + conn, connErr = ccreds.ClientHandshake(addr, conn, timeout) break } } diff --git a/transport/transport.go b/transport/transport.go index de2bdb41b..5dfd89f0d 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -316,7 +316,6 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv // ConnectOptions covers all relevant options for dialing a server. type ConnectOptions struct { Dialer func(string, time.Duration) (net.Conn, error) - Handshaker func(conn net.Conn) (credentials.TransportAuthenticator, error) AuthOptions []credentials.Credentials Timeout time.Duration } diff --git a/transport/transport_test.go b/transport/transport_test.go index 22fc1d7e9..adbb3e002 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -47,7 +47,6 @@ import ( "github.com/bradfitz/http2" "golang.org/x/net/context" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" ) @@ -61,7 +60,6 @@ type server struct { } var ( - tlsDir = "testdata/" expectedRequest = []byte("ping") expectedResponse = []byte("pong") expectedRequestLarge = make([]byte, initialWindowSize*2) @@ -129,7 +127,7 @@ func (h *testStreamHandler) handleStreamMisbehave(s *Stream) { } // start starts server. Other goroutines should block on s.readyChan for futher operations. -func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) { +func (s *server) start(port int, maxStreams uint32, ht hType) { var err error if port == 0 { s.lis, err = net.Listen("tcp", ":0") @@ -139,13 +137,6 @@ func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) { if err != nil { grpclog.Fatalf("failed to listen: %v", err) } - if useTLS { - creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") - if err != nil { - grpclog.Fatalf("Failed to generate credentials %v", err) - } - s.lis = creds.NewListener(s.lis) - } _, p, err := net.SplitHostPort(s.lis.Addr().String()) if err != nil { grpclog.Fatalf("failed to parse listener address: %v", err) @@ -202,27 +193,16 @@ func (s *server) stop() { s.mu.Unlock() } -func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { +func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { server := &server{readyChan: make(chan bool)} - go server.start(useTLS, port, maxStreams, ht) + go server.start(port, maxStreams, ht) server.wait(t, 2*time.Second) addr := "localhost:" + server.port var ( ct ClientTransport connErr error ) - if useTLS { - creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") - if err != nil { - t.Fatalf("Failed to create credentials %v", err) - } - dopts := ConnectOptions{ - AuthOptions: []credentials.Credentials{creds}, - } - ct, connErr = NewClientTransport(addr, &dopts) - } else { - ct, connErr = NewClientTransport(addr, &ConnectOptions{}) - } + ct, connErr = NewClientTransport(addr, &ConnectOptions{}) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) } @@ -230,7 +210,7 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*s } func TestClientSendAndReceive(t *testing.T) { - server, ct := setUp(t, true, 0, math.MaxUint32, normal) + server, ct := setUp(t, 0, math.MaxUint32, normal) callHdr := &CallHdr{ Host: "localhost", Method: "foo.Small", @@ -270,7 +250,7 @@ func TestClientSendAndReceive(t *testing.T) { } func TestClientErrorNotify(t *testing.T) { - server, ct := setUp(t, true, 0, math.MaxUint32, normal) + server, ct := setUp(t, 0, math.MaxUint32, normal) go server.stop() // ct.reader should detect the error and activate ct.Error(). <-ct.Error() @@ -304,7 +284,7 @@ func performOneRPC(ct ClientTransport) { } func TestClientMix(t *testing.T) { - s, ct := setUp(t, true, 0, math.MaxUint32, normal) + s, ct := setUp(t, 0, math.MaxUint32, normal) go func(s *server) { time.Sleep(5 * time.Second) s.stop() @@ -320,7 +300,7 @@ func TestClientMix(t *testing.T) { } func TestExceedMaxStreamsLimit(t *testing.T) { - server, ct := setUp(t, true, 0, 1, normal) + server, ct := setUp(t, 0, 1, normal) defer func() { ct.Close() server.stop() @@ -368,7 +348,7 @@ func TestExceedMaxStreamsLimit(t *testing.T) { } func TestLargeMessage(t *testing.T) { - server, ct := setUp(t, true, 0, math.MaxUint32, normal) + server, ct := setUp(t, 0, math.MaxUint32, normal) callHdr := &CallHdr{ Host: "localhost", Method: "foo.Large", @@ -402,7 +382,7 @@ func TestLargeMessage(t *testing.T) { } func TestLargeMessageSuspension(t *testing.T) { - server, ct := setUp(t, true, 0, math.MaxUint32, suspended) + server, ct := setUp(t, 0, math.MaxUint32, suspended) callHdr := &CallHdr{ Host: "localhost", Method: "foo.Large", @@ -424,7 +404,7 @@ func TestLargeMessageSuspension(t *testing.T) { } func TestServerWithMisbehavedClient(t *testing.T) { - server, ct := setUp(t, true, 0, math.MaxUint32, suspended) + server, ct := setUp(t, 0, math.MaxUint32, suspended) callHdr := &CallHdr{ Host: "localhost", Method: "foo", @@ -524,7 +504,7 @@ func TestServerWithMisbehavedClient(t *testing.T) { } func TestClientWithMisbehavedServer(t *testing.T) { - server, ct := setUp(t, true, 0, math.MaxUint32, misbehaved) + server, ct := setUp(t, 0, math.MaxUint32, misbehaved) callHdr := &CallHdr{ Host: "localhost", Method: "foo",