revert handshaker changes

This commit is contained in:
iamqizhao 2015-05-12 17:59:20 -07:00
parent 923d211a3d
commit 3617cd5ab3
10 changed files with 67 additions and 89 deletions

View File

@ -107,11 +107,11 @@ func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) Di
// WithHandshaker returns a DialOption that specifies a function to perform some handshaking // WithHandshaker returns a DialOption that specifies a function to perform some handshaking
// with the server. It is typically used to negotiate the wire protocol version and security // with the server. It is typically used to negotiate the wire protocol version and security
// protocol with the server. // protocol with the server.
func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption { //func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption {
return func(o *dialOptions) { // return func(o *dialOptions) {
o.copts.Handshaker = h // o.copts.Handshaker = h
} // }
} //}
// Dial creates a client connection the given target. // Dial creates a client connection the given target.
// TODO(zhaoq): Have an option to make Dial return immediately without waiting // TODO(zhaoq): Have an option to make Dial return immediately without waiting

View File

@ -84,12 +84,14 @@ type ProtocolInfo struct {
// TransportAuthenticator defines the common interface for all the live gRPC wire // TransportAuthenticator defines the common interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL). // protocols and supported transport security protocols (e.g., TLS, SSL).
type TransportAuthenticator interface { type TransportAuthenticator interface {
// Handshake does the authentication handshake specified by the corresponding // ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn. // authentication protocol on rawConn for clients.
Handshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error)
// ServerHandshake does the authentication handshake for servers.
ServerHandshake(rawConn net.Conn) (net.Conn, error)
// NewListener creates a listener which accepts connections with requested // NewListener creates a listener which accepts connections with requested
// authentication handshake. // authentication handshake.
NewListener(lis net.Listener) net.Listener //NewListener(lis net.Listener) net.Listener
// Info provides the ProtocolInfo of this TransportAuthenticator. // Info provides the ProtocolInfo of this TransportAuthenticator.
Info() ProtocolInfo Info() ProtocolInfo
Credentials Credentials
@ -120,7 +122,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" }
func (timeoutError) Timeout() bool { return true } func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true } func (timeoutError) Temporary() bool { return true }
func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) { func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) {
// borrow some code from tls.DialWithDialer // borrow some code from tls.DialWithDialer
var errChannel chan error var errChannel chan error
if timeout != 0 { if timeout != 0 {
@ -152,9 +154,13 @@ func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duratio
return conn, nil return conn, nil
} }
// NewListener creates a net.Listener using the information in tlsCreds. func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, error) {
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener { conn := tls.Server(rawConn, &c.config)
return tls.NewListener(lis, &c.config) if err := conn.Handshake(); err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
} }
// NewTLS uses c to construct a TransportAuthenticator based on TLS. // NewTLS uses c to construct a TransportAuthenticator based on TLS.

View File

@ -225,15 +225,15 @@ func main() {
if err != nil { if err != nil {
grpclog.Fatalf("failed to listen: %v", err) grpclog.Fatalf("failed to listen: %v", err)
} }
grpcServer := grpc.NewServer() var opts []grpc.ServerOption
pb.RegisterRouteGuideServer(grpcServer, newServer())
if *tls { if *tls {
creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
if err != nil { if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err) grpclog.Fatalf("Failed to generate credentials %v", err)
} }
grpcServer.Serve(creds.NewListener(lis)) opts = []grpc.ServerOption{grpc.Creds(creds)}
} else {
grpcServer.Serve(lis)
} }
grpcServer := grpc.NewServer(opts...)
pb.RegisterRouteGuideServer(grpcServer, newServer())
grpcServer.Serve(lis)
} }

View File

@ -15,6 +15,8 @@ creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
if err != nil { if err != nil {
log.Fatalf("Failed to generate credentials %v", err) log.Fatalf("Failed to generate credentials %v", err)
} }
server := grpc.NewServer(grpc.Creds(creds))
...
server.Serve(creds.NewListener(lis)) server.Serve(creds.NewListener(lis))
``` ```

View File

@ -195,15 +195,15 @@ func main() {
if err != nil { if err != nil {
grpclog.Fatalf("failed to listen: %v", err) grpclog.Fatalf("failed to listen: %v", err)
} }
server := grpc.NewServer() var opts []grpc.ServerOption
testpb.RegisterTestServiceServer(server, &testServer{})
if *useTLS { if *useTLS {
creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
if err != nil { if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err) grpclog.Fatalf("Failed to generate credentials %v", err)
} }
server.Serve(creds.NewListener(lis)) opts = []grpc.ServerOption{grpc.Creds(creds)}
} else {
server.Serve(lis)
} }
server := grpc.NewServer(opts...)
testpb.RegisterTestServiceServer(server, &testServer{})
server.Serve(lis)
} }

View File

@ -44,6 +44,7 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
@ -85,7 +86,7 @@ type Server struct {
} }
type options struct { type options struct {
handshaker func(net.Conn) error creds []credentials.Credentials
codec Codec codec Codec
maxConcurrentStreams uint32 maxConcurrentStreams uint32
} }
@ -93,14 +94,6 @@ type options struct {
// A ServerOption sets options. // A ServerOption sets options.
type ServerOption func(*options) type ServerOption func(*options)
// Handshaker returns a ServerOption that specifies a function to perform user-specified
// handshaking on the connection before it becomes usable for gRPC.
func Handshaker(f func(net.Conn) error) ServerOption {
return func(o *options) {
o.handshaker = f
}
}
// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
func CustomCodec(codec Codec) ServerOption { func CustomCodec(codec Codec) ServerOption {
return func(o *options) { return func(o *options) {
@ -116,6 +109,13 @@ func MaxConcurrentStreams(n uint32) ServerOption {
} }
} }
// Creds returns a ServerOption that sets credentials for server connections.
func Creds(c credentials.Credentials) ServerOption {
return func(o *options) {
o.creds = append(o.creds, c)
}
}
// NewServer creates a gRPC server which has no service registered and has not // NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
@ -195,12 +195,14 @@ func (s *Server) Serve(lis net.Listener) error {
if err != nil { if err != nil {
return err return err
} }
// Perform handshaking if it is required. for _, o := range s.opts.creds {
if s.opts.handshaker != nil { if creds, ok := o.(credentials.TransportAuthenticator); ok {
if err := s.opts.handshaker(c); err != nil { c, err = creds.ServerHandshake(c)
grpclog.Println("grpc: Server.Serve failed to complete handshake.") if err != nil {
c.Close() grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
continue continue
}
break
} }
} }
s.mu.Lock() s.mu.Lock()

View File

@ -284,27 +284,27 @@ func listTestEnv() []env {
} }
func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) { func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
s = grpc.NewServer(grpc.MaxConcurrentStreams(maxStream)) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)}
la := ":0" la := ":0"
switch e.network { switch e.network {
case "unix": case "unix":
la = "/tmp/testsock" + fmt.Sprintf("%p", s) la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now())
syscall.Unlink(la) syscall.Unlink(la)
} }
lis, err := net.Listen(e.network, la) lis, err := net.Listen(e.network, la)
if err != nil { if err != nil {
grpclog.Fatalf("Failed to listen: %v", err) grpclog.Fatalf("Failed to listen: %v", err)
} }
testpb.RegisterTestServiceServer(s, &testServer{})
if e.security == "tls" { if e.security == "tls" {
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil { if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err) grpclog.Fatalf("Failed to generate credentials %v", err)
} }
go s.Serve(creds.NewListener(lis)) sopts = append(sopts, grpc.Creds(creds))
} else {
go s.Serve(lis)
} }
s = grpc.NewServer(sopts...)
testpb.RegisterTestServiceServer(s, &testServer{})
go s.Serve(lis)
addr := la addr := la
switch e.network { switch e.network {
case "unix": case "unix":

View File

@ -111,17 +111,6 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
} }
// Perform handshake if opts.Handshaker is set.
if opts.Handshaker != nil {
auth, err := opts.Handshaker(conn)
if err != nil {
return nil, ConnectionErrorf("transport: handshaking failed %v", err)
}
// Prepend the resulting authenticator to opts.AuthOptions.
if auth != nil {
opts.AuthOptions = append([]credentials.Credentials{auth}, opts.AuthOptions...)
}
}
for _, c := range opts.AuthOptions { for _, c := range opts.AuthOptions {
if ccreds, ok := c.(credentials.TransportAuthenticator); ok { if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
scheme = "https" scheme = "https"
@ -132,7 +121,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if timeout > 0 { if timeout > 0 {
timeout -= time.Since(startT) timeout -= time.Since(startT)
} }
conn, connErr = ccreds.Handshake(addr, conn, timeout) conn, connErr = ccreds.ClientHandshake(addr, conn, timeout)
break break
} }
} }

View File

@ -316,7 +316,6 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
// ConnectOptions covers all relevant options for dialing a server. // ConnectOptions covers all relevant options for dialing a server.
type ConnectOptions struct { type ConnectOptions struct {
Dialer func(string, time.Duration) (net.Conn, error) Dialer func(string, time.Duration) (net.Conn, error)
Handshaker func(conn net.Conn) (credentials.TransportAuthenticator, error)
AuthOptions []credentials.Credentials AuthOptions []credentials.Credentials
Timeout time.Duration Timeout time.Duration
} }

View File

@ -47,7 +47,6 @@ import (
"github.com/bradfitz/http2" "github.com/bradfitz/http2"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
) )
@ -61,7 +60,6 @@ type server struct {
} }
var ( var (
tlsDir = "testdata/"
expectedRequest = []byte("ping") expectedRequest = []byte("ping")
expectedResponse = []byte("pong") expectedResponse = []byte("pong")
expectedRequestLarge = make([]byte, initialWindowSize*2) expectedRequestLarge = make([]byte, initialWindowSize*2)
@ -129,7 +127,7 @@ func (h *testStreamHandler) handleStreamMisbehave(s *Stream) {
} }
// start starts server. Other goroutines should block on s.readyChan for futher operations. // start starts server. Other goroutines should block on s.readyChan for futher operations.
func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) { func (s *server) start(port int, maxStreams uint32, ht hType) {
var err error var err error
if port == 0 { if port == 0 {
s.lis, err = net.Listen("tcp", ":0") s.lis, err = net.Listen("tcp", ":0")
@ -139,13 +137,6 @@ func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) {
if err != nil { if err != nil {
grpclog.Fatalf("failed to listen: %v", err) grpclog.Fatalf("failed to listen: %v", err)
} }
if useTLS {
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err)
}
s.lis = creds.NewListener(s.lis)
}
_, p, err := net.SplitHostPort(s.lis.Addr().String()) _, p, err := net.SplitHostPort(s.lis.Addr().String())
if err != nil { if err != nil {
grpclog.Fatalf("failed to parse listener address: %v", err) grpclog.Fatalf("failed to parse listener address: %v", err)
@ -202,27 +193,16 @@ func (s *server) stop() {
s.mu.Unlock() s.mu.Unlock()
} }
func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) {
server := &server{readyChan: make(chan bool)} server := &server{readyChan: make(chan bool)}
go server.start(useTLS, port, maxStreams, ht) go server.start(port, maxStreams, ht)
server.wait(t, 2*time.Second) server.wait(t, 2*time.Second)
addr := "localhost:" + server.port addr := "localhost:" + server.port
var ( var (
ct ClientTransport ct ClientTransport
connErr error connErr error
) )
if useTLS { ct, connErr = NewClientTransport(addr, &ConnectOptions{})
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
dopts := ConnectOptions{
AuthOptions: []credentials.Credentials{creds},
}
ct, connErr = NewClientTransport(addr, &dopts)
} else {
ct, connErr = NewClientTransport(addr, &ConnectOptions{})
}
if connErr != nil { if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr) t.Fatalf("failed to create transport: %v", connErr)
} }
@ -230,7 +210,7 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*s
} }
func TestClientSendAndReceive(t *testing.T) { func TestClientSendAndReceive(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, normal) server, ct := setUp(t, 0, math.MaxUint32, normal)
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
Method: "foo.Small", Method: "foo.Small",
@ -270,7 +250,7 @@ func TestClientSendAndReceive(t *testing.T) {
} }
func TestClientErrorNotify(t *testing.T) { func TestClientErrorNotify(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, normal) server, ct := setUp(t, 0, math.MaxUint32, normal)
go server.stop() go server.stop()
// ct.reader should detect the error and activate ct.Error(). // ct.reader should detect the error and activate ct.Error().
<-ct.Error() <-ct.Error()
@ -304,7 +284,7 @@ func performOneRPC(ct ClientTransport) {
} }
func TestClientMix(t *testing.T) { func TestClientMix(t *testing.T) {
s, ct := setUp(t, true, 0, math.MaxUint32, normal) s, ct := setUp(t, 0, math.MaxUint32, normal)
go func(s *server) { go func(s *server) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
s.stop() s.stop()
@ -320,7 +300,7 @@ func TestClientMix(t *testing.T) {
} }
func TestExceedMaxStreamsLimit(t *testing.T) { func TestExceedMaxStreamsLimit(t *testing.T) {
server, ct := setUp(t, true, 0, 1, normal) server, ct := setUp(t, 0, 1, normal)
defer func() { defer func() {
ct.Close() ct.Close()
server.stop() server.stop()
@ -368,7 +348,7 @@ func TestExceedMaxStreamsLimit(t *testing.T) {
} }
func TestLargeMessage(t *testing.T) { func TestLargeMessage(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, normal) server, ct := setUp(t, 0, math.MaxUint32, normal)
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
Method: "foo.Large", Method: "foo.Large",
@ -402,7 +382,7 @@ func TestLargeMessage(t *testing.T) {
} }
func TestLargeMessageSuspension(t *testing.T) { func TestLargeMessageSuspension(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, suspended) server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
Method: "foo.Large", Method: "foo.Large",
@ -424,7 +404,7 @@ func TestLargeMessageSuspension(t *testing.T) {
} }
func TestServerWithMisbehavedClient(t *testing.T) { func TestServerWithMisbehavedClient(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, suspended) server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
Method: "foo", Method: "foo",
@ -524,7 +504,7 @@ func TestServerWithMisbehavedClient(t *testing.T) {
} }
func TestClientWithMisbehavedServer(t *testing.T) { func TestClientWithMisbehavedServer(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, misbehaved) server, ct := setUp(t, 0, math.MaxUint32, misbehaved)
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
Method: "foo", Method: "foo",