mirror of https://github.com/grpc/grpc-go.git
				
				
				
			transport/http2: use HTTP 400 for bad requests instead of 500 (#5804)
This commit is contained in:
		
							parent
							
								
									5003029eb6
								
							
						
					
					
						commit
						2f413c4548
					
				| 
						 | 
				
			
			@ -46,24 +46,32 @@ import (
 | 
			
		|||
	"google.golang.org/grpc/status"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NewServerHandlerTransport returns a ServerTransport handling gRPC
 | 
			
		||||
// from inside an http.Handler. It requires that the http Server
 | 
			
		||||
// supports HTTP/2.
 | 
			
		||||
// NewServerHandlerTransport returns a ServerTransport handling gRPC from
 | 
			
		||||
// inside an http.Handler, or writes an HTTP error to w and returns an error.
 | 
			
		||||
// It requires that the http Server supports HTTP/2.
 | 
			
		||||
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) {
 | 
			
		||||
	if r.ProtoMajor != 2 {
 | 
			
		||||
		return nil, errors.New("gRPC requires HTTP/2")
 | 
			
		||||
		msg := "gRPC requires HTTP/2"
 | 
			
		||||
		http.Error(w, msg, http.StatusBadRequest)
 | 
			
		||||
		return nil, errors.New(msg)
 | 
			
		||||
	}
 | 
			
		||||
	if r.Method != "POST" {
 | 
			
		||||
		return nil, errors.New("invalid gRPC request method")
 | 
			
		||||
		msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
 | 
			
		||||
		http.Error(w, msg, http.StatusBadRequest)
 | 
			
		||||
		return nil, errors.New(msg)
 | 
			
		||||
	}
 | 
			
		||||
	contentType := r.Header.Get("Content-Type")
 | 
			
		||||
	// TODO: do we assume contentType is lowercase? we did before
 | 
			
		||||
	contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
 | 
			
		||||
	if !validContentType {
 | 
			
		||||
		return nil, errors.New("invalid gRPC request content-type")
 | 
			
		||||
		msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType)
 | 
			
		||||
		http.Error(w, msg, http.StatusBadRequest)
 | 
			
		||||
		return nil, errors.New(msg)
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := w.(http.Flusher); !ok {
 | 
			
		||||
		return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
 | 
			
		||||
		msg := "gRPC requires a ResponseWriter supporting http.Flusher"
 | 
			
		||||
		http.Error(w, msg, http.StatusInternalServerError)
 | 
			
		||||
		return nil, errors.New(msg)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	st := &serverHandlerTransport{
 | 
			
		||||
| 
						 | 
				
			
			@ -79,7 +87,9 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
 | 
			
		|||
	if v := r.Header.Get("grpc-timeout"); v != "" {
 | 
			
		||||
		to, err := decodeTimeout(v)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err)
 | 
			
		||||
			msg := fmt.Sprintf("malformed time-out: %v", err)
 | 
			
		||||
			http.Error(w, msg, http.StatusBadRequest)
 | 
			
		||||
			return nil, status.Error(codes.Internal, msg)
 | 
			
		||||
		}
 | 
			
		||||
		st.timeoutSet = true
 | 
			
		||||
		st.timeout = to
 | 
			
		||||
| 
						 | 
				
			
			@ -97,7 +107,9 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
 | 
			
		|||
		for _, v := range vv {
 | 
			
		||||
			v, err := decodeMetadataHeader(k, v)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err)
 | 
			
		||||
				msg := fmt.Sprintf("malformed binary metadata %q in header %q: %v", v, k, err)
 | 
			
		||||
				http.Error(w, msg, http.StatusBadRequest)
 | 
			
		||||
				return nil, status.Error(codes.Internal, msg)
 | 
			
		||||
			}
 | 
			
		||||
			metakv = append(metakv, k, v)
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -63,7 +63,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
 | 
			
		|||
				Method:     "GET",
 | 
			
		||||
				Header:     http.Header{},
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "invalid gRPC request method",
 | 
			
		||||
			wantErr: `invalid gRPC request method "GET"`,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "bad content type",
 | 
			
		||||
| 
						 | 
				
			
			@ -74,7 +74,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
 | 
			
		|||
					"Content-Type": {"application/foo"},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: "invalid gRPC request content-type",
 | 
			
		||||
			wantErr: `invalid gRPC request content-type "application/foo"`,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "not flusher",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1008,7 +1008,8 @@ var _ http.Handler = (*Server)(nil)
 | 
			
		|||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		http.Error(w, err.Error(), http.StatusInternalServerError)
 | 
			
		||||
		// Errors returned from transport.NewServerHandlerTransport have
 | 
			
		||||
		// already been written to w.
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !s.addConn(listenerAddressForServeHTTP, st) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue