diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index ff8f4db08..746423ff7 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1140,15 +1140,30 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { if !ok { return } + endStream := frame.StreamEnded() atomic.StoreUint32(&s.bytesReceived, 1) - var state decodeState - if err := state.decodeHeader(frame); err != nil { - t.closeStream(s, err, true, http2.ErrCodeProtocol, status.New(codes.Internal, err.Error()), nil, false) - // Something wrong. Stops reading even when there is remaining. + initialHeader := atomic.SwapUint32(&s.headerDone, 1) == 0 + + if !initialHeader && !endStream { + // As specified by RFC 7540, a HEADERS frame (and associated CONTINUATION frames) can only appear + // at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set. + st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream") + t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false) + return + } + + state := &decodeState{ + serverSide: false, + ignoreContentType: !initialHeader, + } + // Initialize isGRPC value to be !initialHeader, since if a gRPC ResponseHeader has been received + // which indicates peer speaking gRPC, we are in gRPC mode. + state.data.isGRPC = !initialHeader + if err := state.decodeHeader(frame); err != nil { + t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream) return } - endStream := frame.StreamEnded() var isHeader bool defer func() { if t.statsHandler != nil { @@ -1167,29 +1182,30 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } } }() + // If headers haven't been received yet. - if atomic.SwapUint32(&s.headerDone, 1) == 0 { + if initialHeader { if !endStream { - // Headers frame is not actually a trailers-only frame. + // Headers frame is ResponseHeader. isHeader = true // These values can be set without any synchronization because // stream goroutine will read it only after seeing a closed // headerChan which we'll close after setting this. - s.recvCompress = state.encoding - if len(state.mdata) > 0 { - s.header = state.mdata + s.recvCompress = state.data.encoding + if len(state.data.mdata) > 0 { + s.header = state.data.mdata } - } else { - s.noHeaders = true + close(s.headerChan) + return } + // Headers frame is Trailers-only. + s.noHeaders = true close(s.headerChan) } - if !endStream { - return - } + // if client received END_STREAM from server while stream was still active, send RST_STREAM rst := s.getState() == streamActive - t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.mdata, true) + t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true) } // reader runs as a separate goroutine in charge of reading data from network diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 2b996f641..19ff6edcc 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -286,7 +286,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) { streamID := frame.Header().StreamID - state := decodeState{serverSide: true} + state := &decodeState{ + serverSide: true, + ignoreContentType: false, + } if err := state.decodeHeader(frame); err != nil { if se, ok := status.FromError(err); ok { t.controlBuf.put(&cleanupStream{ @@ -305,16 +308,16 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( st: t, buf: buf, fc: &inFlow{limit: uint32(t.initialWindowSize)}, - recvCompress: state.encoding, - method: state.method, - contentSubtype: state.contentSubtype, + recvCompress: state.data.encoding, + method: state.data.method, + contentSubtype: state.data.contentSubtype, } if frame.StreamEnded() { // s is just created by the caller. No lock needed. s.state = streamReadDone } - if state.timeoutSet { - s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) + if state.data.timeoutSet { + s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout) } else { s.ctx, s.cancel = context.WithCancel(t.ctx) } @@ -327,19 +330,19 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } s.ctx = peer.NewContext(s.ctx, pr) // Attach the received metadata to the context. - if len(state.mdata) > 0 { - s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) + if len(state.data.mdata) > 0 { + s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) } - if state.statsTags != nil { - s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags) + if state.data.statsTags != nil { + s.ctx = stats.SetIncomingTags(s.ctx, state.data.statsTags) } - if state.statsTrace != nil { - s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace) + if state.data.statsTrace != nil { + s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace) } if t.inTapHandle != nil { var err error info := &tap.Info{ - FullMethodName: state.method, + FullMethodName: state.data.method, } s.ctx, err = t.inTapHandle(s.ctx, info) if err != nil { diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index 77a2cfaae..de0e7264b 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -78,7 +78,8 @@ var ( codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm, codes.PermissionDenied: http2.ErrCodeInadequateSecurity, } - httpStatusConvTab = map[int]codes.Code{ + // HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table. + HTTPStatusConvTab = map[int]codes.Code{ // 400 Bad Request - INTERNAL. http.StatusBadRequest: codes.Internal, // 401 Unauthorized - UNAUTHENTICATED. @@ -98,9 +99,7 @@ var ( } ) -// Records the states during HPACK decoding. Must be reset once the -// decoding of the entire headers are finished. -type decodeState struct { +type parsedHeaderData struct { encoding string // statusGen caches the stream status received from the trailer the server // sent. Client side only. Do not access directly. After all trailers are @@ -120,8 +119,38 @@ type decodeState struct { statsTags []byte statsTrace []byte contentSubtype string + + // isGRPC field indicates whether the peer is speaking gRPC (otherwise HTTP). + // + // We are in gRPC mode (peer speaking gRPC) if: + // * We are client side and have already received a HEADER frame that indicates gRPC peer. + // * The header contains valid a content-type, i.e. a string starts with "application/grpc" + // And we should handle error specific to gRPC. + // + // Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we + // are in HTTP fallback mode, and should handle error specific to HTTP. + isGRPC bool + grpcErr error + httpErr error + contentTypeErr string +} + +// decodeState configures decoding criteria and records the decoded data. +type decodeState struct { // whether decoding on server side or not serverSide bool + // ignoreContentType indicates whether when processing the HEADERS frame, ignoring checking the + // content-type is grpc or not. + // + // Trailers (after headers) should not have a content-type. And thus we will ignore checking the + // content-type. + // + // For server, this field is always false. + ignoreContentType bool + + // Records the states during HPACK decoding. It will be filled with info parsed from HTTP HEADERS + // frame once decodeHeader function has been invoked and returned. + data parsedHeaderData } // isReservedHeader checks whether hdr belongs to HTTP2 headers @@ -202,11 +231,11 @@ func contentType(contentSubtype string) string { } func (d *decodeState) status() *status.Status { - if d.statusGen == nil { + if d.data.statusGen == nil { // No status-details were provided; generate status using code/msg. - d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg) + d.data.statusGen = status.New(codes.Code(int32(*(d.data.rawStatusCode))), d.data.rawStatusMsg) } - return d.statusGen + return d.data.statusGen } const binHdrSuffix = "-bin" @@ -244,113 +273,146 @@ func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) error { if frame.Truncated { return status.Error(codes.Internal, "peer header list size exceeded limit") } + for _, hf := range frame.Fields { - if err := d.processHeaderField(hf); err != nil { - return err + d.processHeaderField(hf) + } + + if d.data.isGRPC { + if d.data.grpcErr != nil { + return d.data.grpcErr + } + if d.serverSide { + return nil + } + if d.data.rawStatusCode == nil && d.data.statusGen == nil { + // gRPC status doesn't exist. + // Set rawStatusCode to be unknown and return nil error. + // So that, if the stream has ended this Unknown status + // will be propagated to the user. + // Otherwise, it will be ignored. In which case, status from + // a later trailer, that has StreamEnded flag set, is propagated. + code := int(codes.Unknown) + d.data.rawStatusCode = &code } - } - - if d.serverSide { return nil } - // If grpc status exists, no need to check further. - if d.rawStatusCode != nil || d.statusGen != nil { - return nil + // HTTP fallback mode + if d.data.httpErr != nil { + return d.data.httpErr } - // If grpc status doesn't exist and http status doesn't exist, - // then it's a malformed header. - if d.httpStatus == nil { - return status.Error(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)") - } + var ( + code = codes.Internal // when header does not include HTTP status, return INTERNAL + ok bool + ) - if *(d.httpStatus) != http.StatusOK { - code, ok := httpStatusConvTab[*(d.httpStatus)] + if d.data.httpStatus != nil { + code, ok = HTTPStatusConvTab[*(d.data.httpStatus)] if !ok { code = codes.Unknown } - return status.Error(code, http.StatusText(*(d.httpStatus))) } - // gRPC status doesn't exist and http status is OK. - // Set rawStatusCode to be unknown and return nil error. - // So that, if the stream has ended this Unknown status - // will be propagated to the user. - // Otherwise, it will be ignored. In which case, status from - // a later trailer, that has StreamEnded flag set, is propagated. - code := int(codes.Unknown) - d.rawStatusCode = &code - return nil + return status.Error(code, d.constructHTTPErrMsg()) +} + +// constructErrMsg constructs error message to be returned in HTTP fallback mode. +// Format: HTTP status code and its corresponding message + content-type error message. +func (d *decodeState) constructHTTPErrMsg() string { + var errMsgs []string + + if d.data.httpStatus == nil { + errMsgs = append(errMsgs, "malformed header: missing HTTP status") + } else { + errMsgs = append(errMsgs, fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*(d.data.httpStatus)), *d.data.httpStatus)) + } + + if d.data.contentTypeErr == "" { + errMsgs = append(errMsgs, "transport: missing content-type field") + } else { + errMsgs = append(errMsgs, d.data.contentTypeErr) + } + + return strings.Join(errMsgs, "; ") } func (d *decodeState) addMetadata(k, v string) { - if d.mdata == nil { - d.mdata = make(map[string][]string) + if d.data.mdata == nil { + d.data.mdata = make(map[string][]string) } - d.mdata[k] = append(d.mdata[k], v) + d.data.mdata[k] = append(d.data.mdata[k], v) } -func (d *decodeState) processHeaderField(f hpack.HeaderField) error { +func (d *decodeState) processHeaderField(f hpack.HeaderField) { switch f.Name { case "content-type": contentSubtype, validContentType := contentSubtype(f.Value) if !validContentType { - return status.Errorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value) + d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value) + return } - d.contentSubtype = contentSubtype + d.data.contentSubtype = contentSubtype // TODO: do we want to propagate the whole content-type in the metadata, // or come up with a way to just propagate the content-subtype if it was set? // ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"} // in the metadata? d.addMetadata(f.Name, f.Value) + d.data.isGRPC = true case "grpc-encoding": - d.encoding = f.Value + d.data.encoding = f.Value case "grpc-status": code, err := strconv.Atoi(f.Value) if err != nil { - return status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err) + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err) + return } - d.rawStatusCode = &code + d.data.rawStatusCode = &code case "grpc-message": - d.rawStatusMsg = decodeGrpcMessage(f.Value) + d.data.rawStatusMsg = decodeGrpcMessage(f.Value) case "grpc-status-details-bin": v, err := decodeBinHeader(f.Value) if err != nil { - return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + return } s := &spb.Status{} if err := proto.Unmarshal(v, s); err != nil { - return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + return } - d.statusGen = status.FromProto(s) + d.data.statusGen = status.FromProto(s) case "grpc-timeout": - d.timeoutSet = true + d.data.timeoutSet = true var err error - if d.timeout, err = decodeTimeout(f.Value); err != nil { - return status.Errorf(codes.Internal, "transport: malformed time-out: %v", err) + if d.data.timeout, err = decodeTimeout(f.Value); err != nil { + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err) } case ":path": - d.method = f.Value + d.data.method = f.Value case ":status": code, err := strconv.Atoi(f.Value) if err != nil { - return status.Errorf(codes.Internal, "transport: malformed http-status: %v", err) + d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err) + return } - d.httpStatus = &code + d.data.httpStatus = &code case "grpc-tags-bin": v, err := decodeBinHeader(f.Value) if err != nil { - return status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) + return } - d.statsTags = v + d.data.statsTags = v d.addMetadata(f.Name, string(v)) case "grpc-trace-bin": v, err := decodeBinHeader(f.Value) if err != nil { - return status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) + d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) + return } - d.statsTrace = v + d.data.statsTrace = v d.addMetadata(f.Name, string(v)) default: if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) { @@ -359,11 +421,10 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error { v, err := decodeMetadataHeader(f.Name, f.Value) if err != nil { errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err) - return nil + return } d.addMetadata(f.Name, v) } - return nil } type timeoutUnit uint8 diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 2580aa7d3..e0501c998 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -327,8 +327,7 @@ func (s *Stream) TrailersOnly() (bool, error) { if err != nil { return false, err } - // if !headerDone, some other connection error occurred. - return s.noHeaders && atomic.LoadUint32(&s.headerDone) == 1, nil + return s.noHeaders, nil } // Trailer returns the cached trailer metedata. Note that if it is not called diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index e3857356a..baea6befd 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -19,7 +19,6 @@ package transport import ( - "bufio" "bytes" "context" "encoding/binary" @@ -28,7 +27,6 @@ import ( "io" "math" "net" - "net/http" "reflect" "runtime" "strconv" @@ -1943,167 +1941,6 @@ func waitWhileTrue(t *testing.T, condition func() (bool, error)) { } } -// A function of type writeHeaders writes out -// http status with the given stream ID using the given framer. -type writeHeaders func(*http2.Framer, uint32, int) error - -func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error { - var buf bytes.Buffer - henc := hpack.NewEncoder(&buf) - henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)}) - return framer.WriteHeaders(http2.HeadersFrameParam{ - StreamID: sid, - BlockFragment: buf.Bytes(), - EndStream: true, - EndHeaders: true, - }) -} - -func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error { - var buf bytes.Buffer - henc := hpack.NewEncoder(&buf) - henc.WriteField(hpack.HeaderField{ - Name: ":status", - Value: fmt.Sprint(http.StatusOK), - }) - if err := framer.WriteHeaders(http2.HeadersFrameParam{ - StreamID: sid, - BlockFragment: buf.Bytes(), - EndHeaders: true, - }); err != nil { - return err - } - buf.Reset() - henc.WriteField(hpack.HeaderField{ - Name: ":status", - Value: fmt.Sprint(httpStatus), - }) - return framer.WriteHeaders(http2.HeadersFrameParam{ - StreamID: sid, - BlockFragment: buf.Bytes(), - EndStream: true, - EndHeaders: true, - }) -} - -type httpServer struct { - httpStatus int - wh writeHeaders -} - -func (s *httpServer) start(t *testing.T, lis net.Listener) { - // Launch an HTTP server to send back header with httpStatus. - go func() { - conn, err := lis.Accept() - if err != nil { - t.Errorf("Error accepting connection: %v", err) - return - } - defer conn.Close() - // Read preface sent by client. - if _, err = io.ReadFull(conn, make([]byte, len(http2.ClientPreface))); err != nil { - t.Errorf("Error at server-side while reading preface from client. Err: %v", err) - return - } - reader := bufio.NewReaderSize(conn, defaultWriteBufSize) - writer := bufio.NewWriterSize(conn, defaultReadBufSize) - framer := http2.NewFramer(writer, reader) - if err = framer.WriteSettingsAck(); err != nil { - t.Errorf("Error at server-side while sending Settings ack. Err: %v", err) - return - } - var sid uint32 - // Read frames until a header is received. - for { - frame, err := framer.ReadFrame() - if err != nil { - t.Errorf("Error at server-side while reading frame. Err: %v", err) - return - } - if hframe, ok := frame.(*http2.HeadersFrame); ok { - sid = hframe.Header().StreamID - break - } - } - if err = s.wh(framer, sid, s.httpStatus); err != nil { - t.Errorf("Error at server-side while writing headers. Err: %v", err) - return - } - writer.Flush() - }() -} - -func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (*Stream, func()) { - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen. Err: %v", err) - } - server := &httpServer{ - httpStatus: httpStatus, - wh: wh, - } - server.start(t, lis) - connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) - defer cancel() - client, err := newHTTP2Client(connectCtx, context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, func() {}, func(GoAwayReason) {}, func() {}) - if err != nil { - lis.Close() - t.Fatalf("Error creating client. Err: %v", err) - } - stream, err := client.NewStream(context.Background(), &CallHdr{Method: "bogus/method"}) - if err != nil { - client.Close() - lis.Close() - t.Fatalf("Error creating stream at client-side. Err: %v", err) - } - return stream, func() { - client.Close() - lis.Close() - } -} - -func TestHTTPToGRPCStatusMapping(t *testing.T) { - for k := range httpStatusConvTab { - testHTTPToGRPCStatusMapping(t, k, writeOneHeader) - } -} - -func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) { - stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh) - defer cleanUp() - want := httpStatusConvTab[httpStatus] - buf := make([]byte, 8) - _, err := stream.Read(buf) - if err == nil { - t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want) - } - serr, ok := status.FromError(err) - if !ok { - t.Fatalf("err.(Type) = %T, want status error", err) - } - if want != serr.Code() { - t.Fatalf("Want error code: %v, got: %v", want, serr.Code()) - } -} - -func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) { - stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader) - defer cleanUp() - buf := make([]byte, 8) - _, err := stream.Read(buf) - if err != io.EOF { - t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err) - } - want := codes.Unknown - if stream.status.Code() != want { - t.Fatalf("Status code of stream: %v, want: %v", stream.status.Code(), want) - } -} - -func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) { - testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders) -} - // If any error occurs on a call to Stream.Read, future calls // should continue to return that same error. func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { diff --git a/test/end2end_test.go b/test/end2end_test.go index 2ffa2e334..80a72d649 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -22,6 +22,7 @@ package test import ( + "bufio" "bytes" "context" "crypto/tls" @@ -45,6 +46,7 @@ import ( "github.com/golang/protobuf/proto" anypb "github.com/golang/protobuf/ptypes/any" "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc" "google.golang.org/grpc/balancer/roundrobin" @@ -62,6 +64,7 @@ import ( "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -6892,7 +6895,7 @@ func testClientMaxHeaderListSizeServerIntentionalViolation(t *testing.T, e env) time.Sleep(100 * time.Millisecond) rcw.writeHeaders(http2.HeadersFrameParam{ StreamID: tc.getCurrentStreamID(), - BlockFragment: rcw.encodeHeader("oversize", strings.Join(val, "")), + BlockFragment: rcw.encodeRawHeader("oversize", strings.Join(val, "")), EndStream: false, EndHeaders: true, }) @@ -7125,3 +7128,221 @@ func (s) TestRPCWaitsForResolver(t *testing.T) { t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err) } } + +func (s) TestHTTPHeaderFrameErrorHandlingHTTPMode(t *testing.T) { + // Non-gRPC content-type fallback path. + for httpCode := range transport.HTTPStatusConvTab { + doHTTPHeaderTest(t, transport.HTTPStatusConvTab[int(httpCode)], []string{ + ":status", fmt.Sprintf("%d", httpCode), + "content-type", "text/html", // non-gRPC content type to switch to HTTP mode. + "grpc-status", "1", // Make up a gRPC status error + "grpc-status-details-bin", "???", // Make up a gRPC field parsing error + }) + } + + // Missing content-type fallback path. + for httpCode := range transport.HTTPStatusConvTab { + doHTTPHeaderTest(t, transport.HTTPStatusConvTab[int(httpCode)], []string{ + ":status", fmt.Sprintf("%d", httpCode), + // Omitting content type to switch to HTTP mode. + "grpc-status", "1", // Make up a gRPC status error + "grpc-status-details-bin", "???", // Make up a gRPC field parsing error + }) + } + + // Malformed HTTP status when fallback. + doHTTPHeaderTest(t, codes.Internal, []string{ + ":status", "abc", + // Omitting content type to switch to HTTP mode. + "grpc-status", "1", // Make up a gRPC status error + "grpc-status-details-bin", "???", // Make up a gRPC field parsing error + }) +} + +// Testing erroneous ReponseHeader or Trailers-only (delivered in the first HEADERS frame). +func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { + for _, test := range []struct { + header []string + errCode codes.Code + }{ + { + // missing gRPC status. + header: []string{ + ":status", "403", + "content-type", "application/grpc", + }, + errCode: codes.Unknown, + }, + { + // malformed grpc-status. + header: []string{ + ":status", "502", + "content-type", "application/grpc", + "grpc-status", "abc", + }, + errCode: codes.Internal, + }, + { + // Malformed grpc-tags-bin field. + header: []string{ + ":status", "502", + "content-type", "application/grpc", + "grpc-status", "0", + "grpc-tags-bin", "???", + }, + errCode: codes.Internal, + }, + { + // gRPC status error. + header: []string{ + ":status", "502", + "content-type", "application/grpc", + "grpc-status", "3", + }, + errCode: codes.InvalidArgument, + }, + } { + doHTTPHeaderTest(t, test.errCode, test.header) + } +} + +// Testing non-Trailers-only Trailers (delievered in second HEADERS frame) +func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { + for _, test := range []struct { + responseHeader []string + trailer []string + errCode codes.Code + }{ + { + responseHeader: []string{ + ":status", "200", + "content-type", "application/grpc", + }, + trailer: []string{ + // trailer missing grpc-status + ":status", "502", + }, + errCode: codes.Unknown, + }, + { + responseHeader: []string{ + ":status", "404", + "content-type", "application/grpc", + }, + trailer: []string{ + // malformed grpc-status-details-bin field + "grpc-status", "0", + "grpc-status-details-bin", "????", + }, + errCode: codes.Internal, + }, + } { + doHTTPHeaderTest(t, test.errCode, test.responseHeader, test.trailer) + } +} + +func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) { + header := []string{ + ":status", "200", + "content-type", "application/grpc", + } + doHTTPHeaderTest(t, codes.Internal, header, header, header) +} + +type httpServer struct { + headerFields [][]string +} + +func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error { + if len(headerFields)%2 == 1 { + panic("odd number of kv args") + } + + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) + for len(headerFields) > 0 { + k, v := headerFields[0], headerFields[1] + headerFields = headerFields[2:] + henc.WriteField(hpack.HeaderField{Name: k, Value: v}) + } + + return framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: sid, + BlockFragment: buf.Bytes(), + EndStream: endStream, + EndHeaders: true, + }) +} + +func (s *httpServer) start(t *testing.T, lis net.Listener) { + // Launch an HTTP server to send back header. + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + defer conn.Close() + // Read preface sent by client. + if _, err = io.ReadFull(conn, make([]byte, len(http2.ClientPreface))); err != nil { + t.Errorf("Error at server-side while reading preface from client. Err: %v", err) + return + } + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + framer := http2.NewFramer(writer, reader) + if err = framer.WriteSettingsAck(); err != nil { + t.Errorf("Error at server-side while sending Settings ack. Err: %v", err) + return + } + writer.Flush() // necessary since client is expecting preface before declaring connection fully setup. + + var sid uint32 + // Read frames until a header is received. + for { + frame, err := framer.ReadFrame() + if err != nil { + t.Errorf("Error at server-side while reading frame. Err: %v", err) + return + } + if hframe, ok := frame.(*http2.HeadersFrame); ok { + sid = hframe.Header().StreamID + break + } + } + for i, headers := range s.headerFields { + if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil { + t.Errorf("Error at server-side while writing headers. Err: %v", err) + return + } + writer.Flush() + } + }() +} + +func doHTTPHeaderTest(t *testing.T, errCode codes.Code, headerFields ...[]string) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen. Err: %v", err) + } + defer lis.Close() + server := &httpServer{ + headerFields: headerFields, + } + server.start(t, lis) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("failed to dial due to err: %v", err) + } + defer cc.Close() + client := testpb.NewTestServiceClient(cc) + stream, err := client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("error creating stream due to err: %v", err) + } + if _, err := stream.Recv(); err == nil || status.Code(err) != errCode { + t.Fatalf("stream.Recv() = _, %v, want error code: %v", err, errCode) + } +} diff --git a/test/rawConnWrapper.go b/test/rawConnWrapper.go index 5d991cf01..124b10e09 100644 --- a/test/rawConnWrapper.go +++ b/test/rawConnWrapper.go @@ -227,6 +227,47 @@ func (rcw *rawConnWrapper) encodeHeaderField(k, v string) error { return nil } +// encodeRawHeader is for usage on both client and server side to construct header based on the input +// key, value pairs. +func (rcw *rawConnWrapper) encodeRawHeader(headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + + rcw.headerBuf.Reset() + + pseudoCount := map[string]int{} + var keys []string + vals := map[string][]string{} + + for len(headers) > 0 { + k, v := headers[0], headers[1] + headers = headers[2:] + if _, ok := vals[k]; !ok { + keys = append(keys, k) + } + if strings.HasPrefix(k, ":") { + pseudoCount[k]++ + if pseudoCount[k] == 1 { + vals[k] = []string{v} + } else { + // Allows testing of invalid headers w/ dup pseudo fields. + vals[k] = append(vals[k], v) + } + } else { + vals[k] = append(vals[k], v) + } + } + for _, k := range keys { + for _, v := range vals[k] { + rcw.encodeHeaderField(k, v) + } + } + return rcw.headerBuf.Bytes() +} + +// encodeHeader is for usage on client side to write request header. +// // encodeHeader encodes headers and returns their HPACK bytes. headers // must contain an even number of key/value pairs. There may be // multiple pairs for keys (e.g. "cookie"). The :method, :path, and @@ -288,6 +329,7 @@ func (rcw *rawConnWrapper) encodeHeader(headers ...string) []byte { return rcw.headerBuf.Bytes() } +// writeHeadersGRPC is for usage on client side to write request header. func (rcw *rawConnWrapper) writeHeadersGRPC(streamID uint32, path string) { rcw.writeHeaders(http2.HeadersFrameParam{ StreamID: streamID,