mirror of https://github.com/grpc/grpc-go.git
				
				
				
			Add a ServeHTTP method to *grpc.Server
This adds new http.Handler-based ServerTransport in the process, reusing the HTTP/2 server code in x/net/http2 or Go 1.6+. All end2end tests pass with this new ServerTransport. Fixes grpc/grpc-go#75 Also: Updates grpc/grpc-go#495 (lets user fix it with middleware in front) Updates grpc/grpc-go#468 (x/net/http2 validates) Updates grpc/grpc-go#147 (possible with x/net/http2) Updates grpc/grpc-go#104 (x/net/http2 does this)
This commit is contained in:
		
							parent
							
								
									3c4302b713
								
							
						
					
					
						commit
						7346c871b0
					
				| 
						 | 
				
			
			@ -273,7 +273,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
 | 
			
		|||
	case compressionNone:
 | 
			
		||||
	case compressionMade:
 | 
			
		||||
		if recvCompress == "" {
 | 
			
		||||
			return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
 | 
			
		||||
			return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress)
 | 
			
		||||
		}
 | 
			
		||||
		if dc == nil || recvCompress != dc.Type() {
 | 
			
		||||
			return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										161
									
								
								server.go
								
								
								
								
							
							
						
						
									
										161
									
								
								server.go
								
								
								
								
							| 
						 | 
				
			
			@ -39,6 +39,7 @@ import (
 | 
			
		|||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strings"
 | 
			
		||||
| 
						 | 
				
			
			@ -46,6 +47,7 @@ import (
 | 
			
		|||
	"time"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
	"golang.org/x/net/http2"
 | 
			
		||||
	"golang.org/x/net/trace"
 | 
			
		||||
	"google.golang.org/grpc/codes"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
| 
						 | 
				
			
			@ -83,9 +85,10 @@ type service struct {
 | 
			
		|||
// Server is a gRPC server to serve RPC requests.
 | 
			
		||||
type Server struct {
 | 
			
		||||
	opts options
 | 
			
		||||
	mu     sync.Mutex
 | 
			
		||||
 | 
			
		||||
	mu     sync.Mutex // guards following
 | 
			
		||||
	lis    map[net.Listener]bool
 | 
			
		||||
	conns  map[transport.ServerTransport]bool
 | 
			
		||||
	conns  map[io.Closer]bool
 | 
			
		||||
	m      map[string]*service // service name -> service info
 | 
			
		||||
	events trace.EventLog
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -96,6 +99,7 @@ type options struct {
 | 
			
		|||
	cp                   Compressor
 | 
			
		||||
	dc                   Decompressor
 | 
			
		||||
	maxConcurrentStreams uint32
 | 
			
		||||
	useHandlerImpl       bool // use http.Handler-based server
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A ServerOption sets options.
 | 
			
		||||
| 
						 | 
				
			
			@ -149,7 +153,7 @@ func NewServer(opt ...ServerOption) *Server {
 | 
			
		|||
	s := &Server{
 | 
			
		||||
		lis:   make(map[net.Listener]bool),
 | 
			
		||||
		opts:  opts,
 | 
			
		||||
		conns: make(map[transport.ServerTransport]bool),
 | 
			
		||||
		conns: make(map[io.Closer]bool),
 | 
			
		||||
		m:     make(map[string]*service),
 | 
			
		||||
	}
 | 
			
		||||
	if EnableTracing {
 | 
			
		||||
| 
						 | 
				
			
			@ -216,9 +220,17 @@ var (
 | 
			
		|||
	ErrServerStopped = errors.New("grpc: the server has been stopped")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 | 
			
		||||
	creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return rawConn, nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return creds.ServerHandshake(rawConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Serve accepts incoming connections on the listener lis, creating a new
 | 
			
		||||
// ServerTransport and service goroutine for each. The service goroutines
 | 
			
		||||
// read gRPC request and then call the registered handlers to reply to them.
 | 
			
		||||
// read gRPC requests and then call the registered handlers to reply to them.
 | 
			
		||||
// Service returns when lis.Accept fails.
 | 
			
		||||
func (s *Server) Serve(lis net.Listener) error {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
| 
						 | 
				
			
			@ -235,39 +247,54 @@ func (s *Server) Serve(lis net.Listener) error {
 | 
			
		|||
		delete(s.lis, lis)
 | 
			
		||||
		s.mu.Unlock()
 | 
			
		||||
	}()
 | 
			
		||||
	listenerAddr := lis.Addr()
 | 
			
		||||
	for {
 | 
			
		||||
		c, err := lis.Accept()
 | 
			
		||||
		rawConn, err := lis.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			s.mu.Lock()
 | 
			
		||||
			s.printf("done serving; Accept = %v", err)
 | 
			
		||||
			s.mu.Unlock()
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		var authInfo credentials.AuthInfo
 | 
			
		||||
		if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
 | 
			
		||||
			var conn net.Conn
 | 
			
		||||
			conn, authInfo, err = creds.ServerHandshake(c)
 | 
			
		||||
		// Start a new goroutine to deal with rawConn
 | 
			
		||||
		// so we don't stall this Accept loop goroutine.
 | 
			
		||||
		go s.handleRawConn(listenerAddr, rawConn)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleRawConn is run in its own goroutine and handles a just-accepted
 | 
			
		||||
// connection that has not had any I/O performed on it yet.
 | 
			
		||||
func (s *Server) handleRawConn(listenerAddr net.Addr, rawConn net.Conn) {
 | 
			
		||||
	conn, authInfo, err := s.useTransportAuthenticator(rawConn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.mu.Lock()
 | 
			
		||||
				s.errorf("ServerHandshake(%q) failed: %v", c.RemoteAddr(), err)
 | 
			
		||||
		s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
 | 
			
		||||
		s.mu.Unlock()
 | 
			
		||||
		grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			c = conn
 | 
			
		||||
		rawConn.Close()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	if s.conns == nil {
 | 
			
		||||
		s.mu.Unlock()
 | 
			
		||||
			c.Close()
 | 
			
		||||
			return nil
 | 
			
		||||
		conn.Close()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	s.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
		go s.serveNewHTTP2Transport(c, authInfo)
 | 
			
		||||
	if s.opts.useHandlerImpl {
 | 
			
		||||
		s.serveUsingHandler(listenerAddr, conn)
 | 
			
		||||
	} else {
 | 
			
		||||
		s.serveNewHTTP2Transport(conn, authInfo)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// serveNewHTTP2Transport sets up a new http/2 transport (using the
 | 
			
		||||
// gRPC http2 server transport in transport/http2_server.go) and
 | 
			
		||||
// serves streams on it.
 | 
			
		||||
// This is run in its own goroutine (it does network I/O in
 | 
			
		||||
// transport.NewServerTransport).
 | 
			
		||||
func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
 | 
			
		||||
	st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -299,6 +326,59 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
 | 
			
		|||
	wg.Wait()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ http.Handler = (*Server)(nil)
 | 
			
		||||
 | 
			
		||||
// serveUsingHandler is called from handleRawConn when s is configured
 | 
			
		||||
// to handle requests via the http.Handler interface. It sets up a
 | 
			
		||||
// net/http.Server to handle the just-accepted conn. The http.Server
 | 
			
		||||
// is configured to route all incoming requests (all HTTP/2 streams)
 | 
			
		||||
// to ServeHTTP, which creates a new ServerTransport for each stream.
 | 
			
		||||
// serveUsingHandler blocks until conn closes.
 | 
			
		||||
//
 | 
			
		||||
// This codepath is only used when Server.TestingUseHandlerImpl has
 | 
			
		||||
// been configured. This lets the end2end tests exercise the ServeHTTP
 | 
			
		||||
// method as one of the environment types.
 | 
			
		||||
//
 | 
			
		||||
// conn is the *tls.Conn that's already been authenticated.
 | 
			
		||||
func (s *Server) serveUsingHandler(listenerAddr net.Addr, conn net.Conn) {
 | 
			
		||||
	if !s.addConn(conn) {
 | 
			
		||||
		conn.Close()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer s.removeConn(conn)
 | 
			
		||||
	connDone := make(chan struct{})
 | 
			
		||||
	hs := &http.Server{
 | 
			
		||||
		Handler: s,
 | 
			
		||||
		ConnState: func(c net.Conn, cs http.ConnState) {
 | 
			
		||||
			if cs == http.StateClosed {
 | 
			
		||||
				close(connDone)
 | 
			
		||||
			}
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	if err := http2.ConfigureServer(hs, &http2.Server{
 | 
			
		||||
		MaxConcurrentStreams: s.opts.maxConcurrentStreams,
 | 
			
		||||
	}); err != nil {
 | 
			
		||||
		grpclog.Fatalf("grpc: http2.ConfigureServer: %v", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	hs.Serve(&singleConnListener{addr: listenerAddr, conn: conn})
 | 
			
		||||
	<-connDone
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	st, err := transport.NewServerHandlerTransport(w, r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, err.Error(), http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !s.addConn(st) {
 | 
			
		||||
		st.Close()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer s.removeConn(st)
 | 
			
		||||
	s.serveStreams(st)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
 | 
			
		||||
// If tracing is not enabled, it returns nil.
 | 
			
		||||
func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
 | 
			
		||||
| 
						 | 
				
			
			@ -317,21 +397,21 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
 | 
			
		|||
	return trInfo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) addConn(st transport.ServerTransport) bool {
 | 
			
		||||
func (s *Server) addConn(c io.Closer) bool {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	if s.conns == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	s.conns[st] = true
 | 
			
		||||
	s.conns[c] = true
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) removeConn(st transport.ServerTransport) {
 | 
			
		||||
func (s *Server) removeConn(c io.Closer) {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	if s.conns != nil {
 | 
			
		||||
		delete(s.conns, st)
 | 
			
		||||
		delete(s.conns, c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -606,12 +686,14 @@ func (s *Server) Stop() {
 | 
			
		|||
	cs := s.conns
 | 
			
		||||
	s.conns = nil
 | 
			
		||||
	s.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	for lis := range listeners {
 | 
			
		||||
		lis.Close()
 | 
			
		||||
	}
 | 
			
		||||
	for c := range cs {
 | 
			
		||||
		c.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	if s.events != nil {
 | 
			
		||||
		s.events.Finish()
 | 
			
		||||
| 
						 | 
				
			
			@ -621,16 +703,24 @@ func (s *Server) Stop() {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// TestingCloseConns closes all exiting transports but keeps s.lis accepting new
 | 
			
		||||
// connections. This is for test only now.
 | 
			
		||||
// connections.
 | 
			
		||||
// This is only for tests and is subject to removal.
 | 
			
		||||
func (s *Server) TestingCloseConns() {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	for c := range s.conns {
 | 
			
		||||
		c.Close()
 | 
			
		||||
		delete(s.conns, c)
 | 
			
		||||
	}
 | 
			
		||||
	s.conns = make(map[transport.ServerTransport]bool)
 | 
			
		||||
	s.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TestingUseHandlerImpl enables the http.Handler-based server implementation.
 | 
			
		||||
// It must be called before Serve and requires TLS credentials.
 | 
			
		||||
// This is only for tests and is subject to removal.
 | 
			
		||||
func (s *Server) TestingUseHandlerImpl() {
 | 
			
		||||
	s.opts.useHandlerImpl = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SendHeader sends header metadata. It may be called at most once from a unary
 | 
			
		||||
// RPC handler. The ctx is the RPC handler's Context or one derived from it.
 | 
			
		||||
func SendHeader(ctx context.Context, md metadata.MD) error {
 | 
			
		||||
| 
						 | 
				
			
			@ -661,3 +751,30 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
 | 
			
		|||
	}
 | 
			
		||||
	return stream.SetTrailer(md)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// singleConnListener is a net.Listener that yields a single conn.
 | 
			
		||||
type singleConnListener struct {
 | 
			
		||||
	mu   sync.Mutex
 | 
			
		||||
	addr net.Addr
 | 
			
		||||
	conn net.Conn // nil if done
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ln *singleConnListener) Addr() net.Addr { return ln.addr }
 | 
			
		||||
 | 
			
		||||
func (ln *singleConnListener) Close() error {
 | 
			
		||||
	ln.mu.Lock()
 | 
			
		||||
	defer ln.mu.Unlock()
 | 
			
		||||
	ln.conn = nil
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ln *singleConnListener) Accept() (net.Conn, error) {
 | 
			
		||||
	ln.mu.Lock()
 | 
			
		||||
	defer ln.mu.Unlock()
 | 
			
		||||
	c := ln.conn
 | 
			
		||||
	if c == nil {
 | 
			
		||||
		return nil, io.EOF
 | 
			
		||||
	}
 | 
			
		||||
	ln.conn = nil
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -333,6 +333,7 @@ type env struct {
 | 
			
		|||
	network     string // The type of network such as tcp, unix, etc.
 | 
			
		||||
	dialer      func(addr string, timeout time.Duration) (net.Conn, error)
 | 
			
		||||
	security    string // The security protocol such as TLS, SSH, etc.
 | 
			
		||||
	httpHandler bool   // whether to use the http.Handler ServerTransport; requires TLS
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e env) runnable() bool {
 | 
			
		||||
| 
						 | 
				
			
			@ -347,10 +348,11 @@ var (
 | 
			
		|||
	tcpTLSEnv    = env{name: "tcp-tls", network: "tcp", security: "tls"}
 | 
			
		||||
	unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer}
 | 
			
		||||
	unixTLSEnv   = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"}
 | 
			
		||||
	allEnv       = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv}
 | 
			
		||||
	handlerEnv   = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true}
 | 
			
		||||
	allEnv       = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', or 'unix-tls' to only run the tests for that environment. Empty means all.")
 | 
			
		||||
var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.")
 | 
			
		||||
 | 
			
		||||
func listTestEnv() (envs []env) {
 | 
			
		||||
	if *onlyEnv != "" {
 | 
			
		||||
| 
						 | 
				
			
			@ -393,6 +395,9 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u
 | 
			
		|||
		sopts = append(sopts, grpc.Creds(creds))
 | 
			
		||||
	}
 | 
			
		||||
	s = grpc.NewServer(sopts...)
 | 
			
		||||
	if e.httpHandler {
 | 
			
		||||
		s.TestingUseHandlerImpl()
 | 
			
		||||
	}
 | 
			
		||||
	if hs != nil {
 | 
			
		||||
		healthpb.RegisterHealthServer(s, hs)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -720,7 +725,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
 | 
			
		|||
		t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
 | 
			
		||||
	}
 | 
			
		||||
	if !reflect.DeepEqual(trailer, testTrailerMetadata) {
 | 
			
		||||
		t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata)
 | 
			
		||||
		t.Fatalf("Received trailer metadata %v, want %v", trailer, testTrailerMetadata)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1030,11 +1035,13 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
 | 
			
		|||
		if e.security == "tls" {
 | 
			
		||||
			delete(headerMD, "transport_security_type")
 | 
			
		||||
		}
 | 
			
		||||
		delete(headerMD, "trailer") // ignore if present
 | 
			
		||||
		if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
 | 
			
		||||
			t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
 | 
			
		||||
		}
 | 
			
		||||
		// test the cached value.
 | 
			
		||||
		headerMD, err = stream.Header()
 | 
			
		||||
		delete(headerMD, "trailer") // ignore if present
 | 
			
		||||
		if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
 | 
			
		||||
			t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,329 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2016, Google Inc.
 | 
			
		||||
 * All rights reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * Redistribution and use in source and binary forms, with or without
 | 
			
		||||
 * modification, are permitted provided that the following conditions are
 | 
			
		||||
 * met:
 | 
			
		||||
 *
 | 
			
		||||
 *     * Redistributions of source code must retain the above copyright
 | 
			
		||||
 * notice, this list of conditions and the following disclaimer.
 | 
			
		||||
 *     * Redistributions in binary form must reproduce the above
 | 
			
		||||
 * copyright notice, this list of conditions and the following disclaimer
 | 
			
		||||
 * in the documentation and/or other materials provided with the
 | 
			
		||||
 * distribution.
 | 
			
		||||
 *     * Neither the name of Google Inc. nor the names of its
 | 
			
		||||
 * contributors may be used to endorse or promote products derived from
 | 
			
		||||
 * this software without specific prior written permission.
 | 
			
		||||
 *
 | 
			
		||||
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 | 
			
		||||
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 | 
			
		||||
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 | 
			
		||||
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 | 
			
		||||
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 | 
			
		||||
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 | 
			
		||||
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 | 
			
		||||
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 | 
			
		||||
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 | 
			
		||||
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | 
			
		||||
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
// This file is the implementation of a gRPC server using HTTP/2 which
 | 
			
		||||
// uses the standard Go http2 Server implementation (via the
 | 
			
		||||
// http.Handler interface), rather than speaking low-level HTTP/2
 | 
			
		||||
// frames itself. It is the implementation of *grpc.Server.ServeHTTP.
 | 
			
		||||
 | 
			
		||||
package transport
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
	"golang.org/x/net/http2"
 | 
			
		||||
	"google.golang.org/grpc/codes"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
	"google.golang.org/grpc/metadata"
 | 
			
		||||
	"google.golang.org/grpc/peer"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NewServerHandlerTransport returns a ServerTransport handling gRPC
 | 
			
		||||
// from inside an http.Handler. It requires that the http Server
 | 
			
		||||
// supports HTTP/2.
 | 
			
		||||
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
 | 
			
		||||
	if r.ProtoMajor != 2 {
 | 
			
		||||
		return nil, errors.New("gRPC requires HTTP/2")
 | 
			
		||||
	}
 | 
			
		||||
	if r.Method != "POST" {
 | 
			
		||||
		return nil, errors.New("invalid gRPC request method")
 | 
			
		||||
	}
 | 
			
		||||
	if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
 | 
			
		||||
		return nil, errors.New("invalid gRPC request content-type")
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := w.(http.Flusher); !ok {
 | 
			
		||||
		return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := w.(http.CloseNotifier); !ok {
 | 
			
		||||
		return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	st := &serverHandlerTransport{
 | 
			
		||||
		rw:          w,
 | 
			
		||||
		req:         r,
 | 
			
		||||
		closedCh:    make(chan struct{}),
 | 
			
		||||
		wroteStatus: make(chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if v := r.Header.Get("grpc-timeout"); v != "" {
 | 
			
		||||
		to, err := timeoutDecode(v)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		st.timeoutSet = true
 | 
			
		||||
		st.timeout = to
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var metakv []string
 | 
			
		||||
	for k, vv := range r.Header {
 | 
			
		||||
		k = strings.ToLower(k)
 | 
			
		||||
		if isReservedHeader(k) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		for _, v := range vv {
 | 
			
		||||
			if k == "user-agent" {
 | 
			
		||||
				// user-agent is special. Copying logic of http_util.go.
 | 
			
		||||
				if i := strings.LastIndex(v, " "); i == -1 {
 | 
			
		||||
					// There is no application user agent string being set
 | 
			
		||||
					continue
 | 
			
		||||
				} else {
 | 
			
		||||
					v = v[:i]
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			metakv = append(metakv, k, v)
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	st.headerMD = metadata.Pairs(metakv...)
 | 
			
		||||
 | 
			
		||||
	return st, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// serverHandlerTransport is an implementation of ServerTransport
 | 
			
		||||
// which replies to exactly one gRPC request (exactly one HTTP request),
 | 
			
		||||
// using the net/http.Handler interface. This http.Handler is guranteed
 | 
			
		||||
// at this point to be speaking over HTTP/2, so it's able to speak valid
 | 
			
		||||
// gRPC.
 | 
			
		||||
type serverHandlerTransport struct {
 | 
			
		||||
	rw               http.ResponseWriter
 | 
			
		||||
	req              *http.Request
 | 
			
		||||
	timeoutSet       bool
 | 
			
		||||
	timeout          time.Duration
 | 
			
		||||
	didCommonHeaders bool
 | 
			
		||||
 | 
			
		||||
	headerMD metadata.MD
 | 
			
		||||
 | 
			
		||||
	closeOnce sync.Once
 | 
			
		||||
	closedCh  chan struct{} // closed on Close
 | 
			
		||||
 | 
			
		||||
	wroteStatus chan struct{} // closed on WriteStatus
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ht *serverHandlerTransport) Close() error {
 | 
			
		||||
	ht.closeOnce.Do(ht.closeCloseChanOnce)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) }
 | 
			
		||||
 | 
			
		||||
func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }
 | 
			
		||||
 | 
			
		||||
// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
 | 
			
		||||
// the empty string if unknown.
 | 
			
		||||
type strAddr string
 | 
			
		||||
 | 
			
		||||
func (a strAddr) Network() string {
 | 
			
		||||
	if a != "" {
 | 
			
		||||
		// Per the documentation on net/http.Request.RemoteAddr, if this is
 | 
			
		||||
		// set, it's set to the IP:port of the peer (hence, TCP):
 | 
			
		||||
		// https://golang.org/pkg/net/http/#Request
 | 
			
		||||
		//
 | 
			
		||||
		// If we want to support Unix sockets later, we can
 | 
			
		||||
		// add our own grpc-specific convention within the
 | 
			
		||||
		// grpc codebase to set RemoteAddr to a different
 | 
			
		||||
		// format, or probably better: we can attach it to the
 | 
			
		||||
		// context and use that from serverHandlerTransport.RemoteAddr.
 | 
			
		||||
		return "tcp"
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a strAddr) String() string { return string(a) }
 | 
			
		||||
 | 
			
		||||
func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
 | 
			
		||||
	ht.writeCommonHeaders(s)
 | 
			
		||||
 | 
			
		||||
	// And flush, in case no header or body has been sent yet.
 | 
			
		||||
	// This forces a separation of headers and trailers if this is the
 | 
			
		||||
	// first call (for example, in end2end tests's TestNoService).
 | 
			
		||||
	ht.rw.(http.Flusher).Flush()
 | 
			
		||||
 | 
			
		||||
	h := ht.rw.Header()
 | 
			
		||||
	h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
 | 
			
		||||
	if statusDesc != "" {
 | 
			
		||||
		h.Set("Grpc-Message", statusDesc)
 | 
			
		||||
	}
 | 
			
		||||
	if md := s.Trailer(); len(md) > 0 {
 | 
			
		||||
		for k, vv := range md {
 | 
			
		||||
			for _, v := range vv {
 | 
			
		||||
				// http2 ResponseWriter mechanism to
 | 
			
		||||
				// send undeclared Trailers after the
 | 
			
		||||
				// headers have possibly been written.
 | 
			
		||||
				h.Add(http2.TrailerPrefix+k, v)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	close(ht.wroteStatus)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// writeCommonHeaders sets common headers on the first write
 | 
			
		||||
// call (Write, WriteHeader, or WriteStatus).
 | 
			
		||||
func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
 | 
			
		||||
	if ht.didCommonHeaders {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	ht.didCommonHeaders = true
 | 
			
		||||
 | 
			
		||||
	h := ht.rw.Header()
 | 
			
		||||
	h["Date"] = nil // suppress Date to make tests happy; TODO: restore
 | 
			
		||||
	h.Set("Content-Type", "application/grpc")
 | 
			
		||||
 | 
			
		||||
	// Predeclare trailers we'll set later in WriteStatus (after the body).
 | 
			
		||||
	// This is a SHOULD in the HTTP RFC, and the way you add (known)
 | 
			
		||||
	// Trailers per the net/http.ResponseWriter contract.
 | 
			
		||||
	// See https://golang.org/pkg/net/http/#ResponseWriter
 | 
			
		||||
	// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
 | 
			
		||||
	h.Add("Trailer", "Grpc-Status")
 | 
			
		||||
	h.Add("Trailer", "Grpc-Message")
 | 
			
		||||
 | 
			
		||||
	if s.sendCompress != "" {
 | 
			
		||||
		h.Set("Grpc-Encoding", s.sendCompress)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
 | 
			
		||||
	ht.writeCommonHeaders(s)
 | 
			
		||||
	ht.rw.Write(data)
 | 
			
		||||
	if !opts.Delay {
 | 
			
		||||
		ht.rw.(http.Flusher).Flush()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
 | 
			
		||||
	ht.writeCommonHeaders(s)
 | 
			
		||||
	h := ht.rw.Header()
 | 
			
		||||
	for k, vv := range md {
 | 
			
		||||
		for _, v := range vv {
 | 
			
		||||
			h.Add(k, v)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	ht.rw.WriteHeader(200)
 | 
			
		||||
	ht.rw.(http.Flusher).Flush()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
 | 
			
		||||
	// With this transport type there will be exactly 1 stream: this HTTP request.
 | 
			
		||||
 | 
			
		||||
	var ctx context.Context
 | 
			
		||||
	var cancel context.CancelFunc
 | 
			
		||||
	if ht.timeoutSet {
 | 
			
		||||
		ctx, cancel = context.WithTimeout(context.Background(), ht.timeout)
 | 
			
		||||
	} else {
 | 
			
		||||
		ctx, cancel = context.WithCancel(context.Background())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// clientGone receives a single value if peer is gone, either
 | 
			
		||||
	// because the underlying connection is dead or because the
 | 
			
		||||
	// peer sends an http2 RST_STREAM.
 | 
			
		||||
	clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
 | 
			
		||||
	go func() {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ht.closedCh:
 | 
			
		||||
		case <-clientGone:
 | 
			
		||||
		}
 | 
			
		||||
		cancel()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	req := ht.req
 | 
			
		||||
 | 
			
		||||
	s := &Stream{
 | 
			
		||||
		id:            0,            // irrelevant
 | 
			
		||||
		windowHandler: func(int) {}, // nothing
 | 
			
		||||
		cancel:        cancel,
 | 
			
		||||
		buf:           newRecvBuffer(),
 | 
			
		||||
		st:            ht,
 | 
			
		||||
		method:        req.URL.Path,
 | 
			
		||||
		recvCompress:  req.Header.Get("grpc-encoding"),
 | 
			
		||||
	}
 | 
			
		||||
	pr := &peer.Peer{
 | 
			
		||||
		Addr: ht.RemoteAddr(),
 | 
			
		||||
	}
 | 
			
		||||
	if req.TLS != nil {
 | 
			
		||||
		pr.AuthInfo = credentials.TLSInfo{*req.TLS}
 | 
			
		||||
	}
 | 
			
		||||
	ctx = metadata.NewContext(ctx, ht.headerMD)
 | 
			
		||||
	ctx = peer.NewContext(ctx, pr)
 | 
			
		||||
	s.ctx = newContextWithStream(ctx, s)
 | 
			
		||||
	s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
 | 
			
		||||
 | 
			
		||||
	// requestOver is closed when either the request's context is done
 | 
			
		||||
	// or the status has been written via WriteStatus.
 | 
			
		||||
	requestOver := make(chan struct{})
 | 
			
		||||
 | 
			
		||||
	// readerDone is closed when the Body.Read-ing goroutine exits.
 | 
			
		||||
	readerDone := make(chan struct{})
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer close(readerDone)
 | 
			
		||||
		for {
 | 
			
		||||
			buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership
 | 
			
		||||
			n, err := req.Body.Read(buf)
 | 
			
		||||
			select {
 | 
			
		||||
			case <-requestOver:
 | 
			
		||||
				return
 | 
			
		||||
			default:
 | 
			
		||||
			}
 | 
			
		||||
			if n > 0 {
 | 
			
		||||
				s.buf.put(&recvMsg{data: buf[:n]})
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				s.buf.put(&recvMsg{err: err})
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// runStream is provided by the *grpc.Server.serveStreams.
 | 
			
		||||
	// It starts a goroutine handling s and exits immediately.
 | 
			
		||||
	runStream(s)
 | 
			
		||||
 | 
			
		||||
	// Wait for the stream to be done. It is considered done when
 | 
			
		||||
	// either its context is done, or we've written its status.
 | 
			
		||||
	select {
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
	case <-ht.wroteStatus:
 | 
			
		||||
	}
 | 
			
		||||
	close(requestOver)
 | 
			
		||||
 | 
			
		||||
	// Wait for reading goroutine to finish.
 | 
			
		||||
	req.Body.Close()
 | 
			
		||||
	<-readerDone
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,386 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2016, Google Inc.
 | 
			
		||||
 * All rights reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * Redistribution and use in source and binary forms, with or without
 | 
			
		||||
 * modification, are permitted provided that the following conditions are
 | 
			
		||||
 * met:
 | 
			
		||||
 *
 | 
			
		||||
 *     * Redistributions of source code must retain the above copyright
 | 
			
		||||
 * notice, this list of conditions and the following disclaimer.
 | 
			
		||||
 *     * Redistributions in binary form must reproduce the above
 | 
			
		||||
 * copyright notice, this list of conditions and the following disclaimer
 | 
			
		||||
 * in the documentation and/or other materials provided with the
 | 
			
		||||
 * distribution.
 | 
			
		||||
 *     * Neither the name of Google Inc. nor the names of its
 | 
			
		||||
 * contributors may be used to endorse or promote products derived from
 | 
			
		||||
 * this software without specific prior written permission.
 | 
			
		||||
 *
 | 
			
		||||
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 | 
			
		||||
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 | 
			
		||||
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 | 
			
		||||
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 | 
			
		||||
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 | 
			
		||||
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 | 
			
		||||
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 | 
			
		||||
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 | 
			
		||||
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 | 
			
		||||
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | 
			
		||||
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package transport
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
	"google.golang.org/grpc/codes"
 | 
			
		||||
	"google.golang.org/grpc/metadata"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
 | 
			
		||||
	type testCase struct {
 | 
			
		||||
		name    string
 | 
			
		||||
		req     *http.Request
 | 
			
		||||
		wantErr string
 | 
			
		||||
		modrw   func(http.ResponseWriter) http.ResponseWriter
 | 
			
		||||
		check   func(*serverHandlerTransport, *testCase) error
 | 
			
		||||
	}
 | 
			
		||||
	tests := []testCase{
 | 
			
		||||
		{
 | 
			
		||||
			name: "http/1.1",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 1,
 | 
			
		||||
				ProtoMinor: 1,
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "gRPC requires HTTP/2",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "bad method",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 2,
 | 
			
		||||
				Method:     "GET",
 | 
			
		||||
				Header:     http.Header{},
 | 
			
		||||
				RequestURI: "/",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "invalid gRPC request method",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "bad content type",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 2,
 | 
			
		||||
				Method:     "POST",
 | 
			
		||||
				Header: http.Header{
 | 
			
		||||
					"Content-Type": {"application/foo"},
 | 
			
		||||
				},
 | 
			
		||||
				RequestURI: "/service/foo.bar",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "invalid gRPC request content-type",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "not flusher",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 2,
 | 
			
		||||
				Method:     "POST",
 | 
			
		||||
				Header: http.Header{
 | 
			
		||||
					"Content-Type": {"application/grpc"},
 | 
			
		||||
				},
 | 
			
		||||
				RequestURI: "/service/foo.bar",
 | 
			
		||||
			},
 | 
			
		||||
			modrw: func(w http.ResponseWriter) http.ResponseWriter {
 | 
			
		||||
				// Return w without its Flush method
 | 
			
		||||
				type onlyCloseNotifier interface {
 | 
			
		||||
					http.ResponseWriter
 | 
			
		||||
					http.CloseNotifier
 | 
			
		||||
				}
 | 
			
		||||
				return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "not closenotifier",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 2,
 | 
			
		||||
				Method:     "POST",
 | 
			
		||||
				Header: http.Header{
 | 
			
		||||
					"Content-Type": {"application/grpc"},
 | 
			
		||||
				},
 | 
			
		||||
				RequestURI: "/service/foo.bar",
 | 
			
		||||
			},
 | 
			
		||||
			modrw: func(w http.ResponseWriter) http.ResponseWriter {
 | 
			
		||||
				// Return w without its CloseNotify method
 | 
			
		||||
				type onlyFlusher interface {
 | 
			
		||||
					http.ResponseWriter
 | 
			
		||||
					http.Flusher
 | 
			
		||||
				}
 | 
			
		||||
				return struct{ onlyFlusher }{w.(onlyFlusher)}
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "gRPC requires a ResponseWriter supporting http.CloseNotifier",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 2,
 | 
			
		||||
				Method:     "POST",
 | 
			
		||||
				Header: http.Header{
 | 
			
		||||
					"Content-Type": {"application/grpc"},
 | 
			
		||||
				},
 | 
			
		||||
				URL: &url.URL{
 | 
			
		||||
					Path: "/service/foo.bar",
 | 
			
		||||
				},
 | 
			
		||||
				RequestURI: "/service/foo.bar",
 | 
			
		||||
			},
 | 
			
		||||
			check: func(t *serverHandlerTransport, tt *testCase) error {
 | 
			
		||||
				if t.req != tt.req {
 | 
			
		||||
					return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
 | 
			
		||||
				}
 | 
			
		||||
				if t.rw == nil {
 | 
			
		||||
					return errors.New("t.rw = nil; want non-nil")
 | 
			
		||||
				}
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "with timeout",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 2,
 | 
			
		||||
				Method:     "POST",
 | 
			
		||||
				Header: http.Header{
 | 
			
		||||
					"Content-Type": []string{"application/grpc"},
 | 
			
		||||
					"Grpc-Timeout": {"200m"},
 | 
			
		||||
				},
 | 
			
		||||
				URL: &url.URL{
 | 
			
		||||
					Path: "/service/foo.bar",
 | 
			
		||||
				},
 | 
			
		||||
				RequestURI: "/service/foo.bar",
 | 
			
		||||
			},
 | 
			
		||||
			check: func(t *serverHandlerTransport, tt *testCase) error {
 | 
			
		||||
				if !t.timeoutSet {
 | 
			
		||||
					return errors.New("timeout not set")
 | 
			
		||||
				}
 | 
			
		||||
				if want := 200 * time.Millisecond; t.timeout != want {
 | 
			
		||||
					return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
 | 
			
		||||
				}
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "with bad timeout",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 2,
 | 
			
		||||
				Method:     "POST",
 | 
			
		||||
				Header: http.Header{
 | 
			
		||||
					"Content-Type": []string{"application/grpc"},
 | 
			
		||||
					"Grpc-Timeout": {"tomorrow"},
 | 
			
		||||
				},
 | 
			
		||||
				URL: &url.URL{
 | 
			
		||||
					Path: "/service/foo.bar",
 | 
			
		||||
				},
 | 
			
		||||
				RequestURI: "/service/foo.bar",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: `stream error: code = 13 desc = "malformed time-out: transport: timeout unit is not recognized: \"tomorrow\""`,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "with metadata",
 | 
			
		||||
			req: &http.Request{
 | 
			
		||||
				ProtoMajor: 2,
 | 
			
		||||
				Method:     "POST",
 | 
			
		||||
				Header: http.Header{
 | 
			
		||||
					"Content-Type": []string{"application/grpc"},
 | 
			
		||||
					"meta-foo":     {"foo-val"},
 | 
			
		||||
					"meta-bar":     {"bar-val1", "bar-val2"},
 | 
			
		||||
					"user-agent":   {"x/y a/b"},
 | 
			
		||||
				},
 | 
			
		||||
				URL: &url.URL{
 | 
			
		||||
					Path: "/service/foo.bar",
 | 
			
		||||
				},
 | 
			
		||||
				RequestURI: "/service/foo.bar",
 | 
			
		||||
			},
 | 
			
		||||
			check: func(ht *serverHandlerTransport, tt *testCase) error {
 | 
			
		||||
				want := metadata.MD{
 | 
			
		||||
					"meta-bar":   {"bar-val1", "bar-val2"},
 | 
			
		||||
					"user-agent": {"x/y"},
 | 
			
		||||
					"meta-foo":   {"foo-val"},
 | 
			
		||||
				}
 | 
			
		||||
				if !reflect.DeepEqual(ht.headerMD, want) {
 | 
			
		||||
					return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want)
 | 
			
		||||
				}
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		rw := newTestHandlerResponseWriter()
 | 
			
		||||
		if tt.modrw != nil {
 | 
			
		||||
			rw = tt.modrw(rw)
 | 
			
		||||
		}
 | 
			
		||||
		got, gotErr := NewServerHandlerTransport(rw, tt.req)
 | 
			
		||||
		if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
 | 
			
		||||
			t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if gotErr != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if tt.check != nil {
 | 
			
		||||
			if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
 | 
			
		||||
				t.Errorf("%s: %v", tt.name, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type testHandlerResponseWriter struct {
 | 
			
		||||
	*httptest.ResponseRecorder
 | 
			
		||||
	closeNotify chan bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w testHandlerResponseWriter) CloseNotify() <-chan bool { return w.closeNotify }
 | 
			
		||||
func (w testHandlerResponseWriter) Flush()                   {}
 | 
			
		||||
 | 
			
		||||
func newTestHandlerResponseWriter() http.ResponseWriter {
 | 
			
		||||
	return testHandlerResponseWriter{
 | 
			
		||||
		ResponseRecorder: httptest.NewRecorder(),
 | 
			
		||||
		closeNotify:      make(chan bool, 1),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type handleStreamTest struct {
 | 
			
		||||
	t     *testing.T
 | 
			
		||||
	bodyw *io.PipeWriter
 | 
			
		||||
	req   *http.Request
 | 
			
		||||
	rw    testHandlerResponseWriter
 | 
			
		||||
	ht    *serverHandlerTransport
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newHandleStreamTest(t *testing.T) *handleStreamTest {
 | 
			
		||||
	bodyr, bodyw := io.Pipe()
 | 
			
		||||
	req := &http.Request{
 | 
			
		||||
		ProtoMajor: 2,
 | 
			
		||||
		Method:     "POST",
 | 
			
		||||
		Header: http.Header{
 | 
			
		||||
			"Content-Type": {"application/grpc"},
 | 
			
		||||
		},
 | 
			
		||||
		URL: &url.URL{
 | 
			
		||||
			Path: "/service/foo.bar",
 | 
			
		||||
		},
 | 
			
		||||
		RequestURI: "/service/foo.bar",
 | 
			
		||||
		Body:       bodyr,
 | 
			
		||||
	}
 | 
			
		||||
	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
 | 
			
		||||
	ht, err := NewServerHandlerTransport(rw, req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	return &handleStreamTest{
 | 
			
		||||
		t:     t,
 | 
			
		||||
		bodyw: bodyw,
 | 
			
		||||
		ht:    ht.(*serverHandlerTransport),
 | 
			
		||||
		rw:    rw,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHandlerTransport_HandleStreams(t *testing.T) {
 | 
			
		||||
	st := newHandleStreamTest(t)
 | 
			
		||||
	st.ht.HandleStreams(func(s *Stream) {
 | 
			
		||||
		if want := "/service/foo.bar"; s.method != want {
 | 
			
		||||
			t.Errorf("stream method = %q; want %q", s.method, want)
 | 
			
		||||
		}
 | 
			
		||||
		st.bodyw.Close() // no body
 | 
			
		||||
		st.ht.WriteStatus(s, codes.OK, "")
 | 
			
		||||
	})
 | 
			
		||||
	wantHeader := http.Header{
 | 
			
		||||
		"Date":         nil,
 | 
			
		||||
		"Content-Type": {"application/grpc"},
 | 
			
		||||
		"Trailer":      {"Grpc-Status", "Grpc-Message"},
 | 
			
		||||
		"Grpc-Status":  {"0"},
 | 
			
		||||
	}
 | 
			
		||||
	if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
 | 
			
		||||
		t.Errorf("Header+Trailer Map: %#v; want %#v", st.rw.HeaderMap, wantHeader)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
 | 
			
		||||
func TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
 | 
			
		||||
	handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tests that codes.InvalidArgument will close the body, per comment in handler_server.go.
 | 
			
		||||
func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
 | 
			
		||||
	handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
 | 
			
		||||
	st := newHandleStreamTest(t)
 | 
			
		||||
	st.ht.HandleStreams(func(s *Stream) {
 | 
			
		||||
		st.ht.WriteStatus(s, statusCode, msg)
 | 
			
		||||
	})
 | 
			
		||||
	wantHeader := http.Header{
 | 
			
		||||
		"Date":         nil,
 | 
			
		||||
		"Content-Type": {"application/grpc"},
 | 
			
		||||
		"Trailer":      {"Grpc-Status", "Grpc-Message"},
 | 
			
		||||
		"Grpc-Status":  {fmt.Sprint(uint32(statusCode))},
 | 
			
		||||
		"Grpc-Message": {msg},
 | 
			
		||||
	}
 | 
			
		||||
	if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
 | 
			
		||||
		t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
 | 
			
		||||
	bodyr, bodyw := io.Pipe()
 | 
			
		||||
	req := &http.Request{
 | 
			
		||||
		ProtoMajor: 2,
 | 
			
		||||
		Method:     "POST",
 | 
			
		||||
		Header: http.Header{
 | 
			
		||||
			"Content-Type": {"application/grpc"},
 | 
			
		||||
			"Grpc-Timeout": {"200m"},
 | 
			
		||||
		},
 | 
			
		||||
		URL: &url.URL{
 | 
			
		||||
			Path: "/service/foo.bar",
 | 
			
		||||
		},
 | 
			
		||||
		RequestURI: "/service/foo.bar",
 | 
			
		||||
		Body:       bodyr,
 | 
			
		||||
	}
 | 
			
		||||
	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
 | 
			
		||||
	ht, err := NewServerHandlerTransport(rw, req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	ht.HandleStreams(func(s *Stream) {
 | 
			
		||||
		defer bodyw.Close()
 | 
			
		||||
		select {
 | 
			
		||||
		case <-s.ctx.Done():
 | 
			
		||||
		case <-time.After(5 * time.Second):
 | 
			
		||||
			t.Errorf("timeout waiting for ctx.Done")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err := s.ctx.Err()
 | 
			
		||||
		if err != context.DeadlineExceeded {
 | 
			
		||||
			t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		ht.WriteStatus(s, codes.DeadlineExceeded, "too slow")
 | 
			
		||||
	})
 | 
			
		||||
	wantHeader := http.Header{
 | 
			
		||||
		"Date":         nil,
 | 
			
		||||
		"Content-Type": {"application/grpc"},
 | 
			
		||||
		"Trailer":      {"Grpc-Status", "Grpc-Message"},
 | 
			
		||||
		"Grpc-Status":  {"4"},
 | 
			
		||||
		"Grpc-Message": {"too slow"},
 | 
			
		||||
	}
 | 
			
		||||
	if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
 | 
			
		||||
		t.Errorf("Header+Trailer Map: %#v; want %#v", rw.HeaderMap, wantHeader)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue