client: handle HTTP header parsing error correctly (#2599)

This commit is contained in:
lyuxuan 2019-03-06 10:59:01 -08:00 committed by GitHub
parent 45890ffd9e
commit 79c9bc6794
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 431 additions and 252 deletions

View File

@ -1140,15 +1140,30 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if !ok {
return
}
endStream := frame.StreamEnded()
atomic.StoreUint32(&s.bytesReceived, 1)
var state decodeState
if err := state.decodeHeader(frame); err != nil {
t.closeStream(s, err, true, http2.ErrCodeProtocol, status.New(codes.Internal, err.Error()), nil, false)
// Something wrong. Stops reading even when there is remaining.
initialHeader := atomic.SwapUint32(&s.headerDone, 1) == 0
if !initialHeader && !endStream {
// As specified by RFC 7540, a HEADERS frame (and associated CONTINUATION frames) can only appear
// at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set.
st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream")
t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false)
return
}
state := &decodeState{
serverSide: false,
ignoreContentType: !initialHeader,
}
// Initialize isGRPC value to be !initialHeader, since if a gRPC ResponseHeader has been received
// which indicates peer speaking gRPC, we are in gRPC mode.
state.data.isGRPC = !initialHeader
if err := state.decodeHeader(frame); err != nil {
t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream)
return
}
endStream := frame.StreamEnded()
var isHeader bool
defer func() {
if t.statsHandler != nil {
@ -1167,29 +1182,30 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
}
}
}()
// If headers haven't been received yet.
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
if initialHeader {
if !endStream {
// Headers frame is not actually a trailers-only frame.
// Headers frame is ResponseHeader.
isHeader = true
// These values can be set without any synchronization because
// stream goroutine will read it only after seeing a closed
// headerChan which we'll close after setting this.
s.recvCompress = state.encoding
if len(state.mdata) > 0 {
s.header = state.mdata
s.recvCompress = state.data.encoding
if len(state.data.mdata) > 0 {
s.header = state.data.mdata
}
} else {
s.noHeaders = true
close(s.headerChan)
return
}
// Headers frame is Trailers-only.
s.noHeaders = true
close(s.headerChan)
}
if !endStream {
return
}
// if client received END_STREAM from server while stream was still active, send RST_STREAM
rst := s.getState() == streamActive
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.mdata, true)
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true)
}
// reader runs as a separate goroutine in charge of reading data from network

View File

@ -286,7 +286,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
// operateHeader takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) {
streamID := frame.Header().StreamID
state := decodeState{serverSide: true}
state := &decodeState{
serverSide: true,
ignoreContentType: false,
}
if err := state.decodeHeader(frame); err != nil {
if se, ok := status.FromError(err); ok {
t.controlBuf.put(&cleanupStream{
@ -305,16 +308,16 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
recvCompress: state.encoding,
method: state.method,
contentSubtype: state.contentSubtype,
recvCompress: state.data.encoding,
method: state.data.method,
contentSubtype: state.data.contentSubtype,
}
if frame.StreamEnded() {
// s is just created by the caller. No lock needed.
s.state = streamReadDone
}
if state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout)
if state.data.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout)
} else {
s.ctx, s.cancel = context.WithCancel(t.ctx)
}
@ -327,19 +330,19 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
}
s.ctx = peer.NewContext(s.ctx, pr)
// Attach the received metadata to the context.
if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
if len(state.data.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata)
}
if state.statsTags != nil {
s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags)
if state.data.statsTags != nil {
s.ctx = stats.SetIncomingTags(s.ctx, state.data.statsTags)
}
if state.statsTrace != nil {
s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace)
if state.data.statsTrace != nil {
s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace)
}
if t.inTapHandle != nil {
var err error
info := &tap.Info{
FullMethodName: state.method,
FullMethodName: state.data.method,
}
s.ctx, err = t.inTapHandle(s.ctx, info)
if err != nil {

View File

@ -78,7 +78,8 @@ var (
codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
codes.PermissionDenied: http2.ErrCodeInadequateSecurity,
}
httpStatusConvTab = map[int]codes.Code{
// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
HTTPStatusConvTab = map[int]codes.Code{
// 400 Bad Request - INTERNAL.
http.StatusBadRequest: codes.Internal,
// 401 Unauthorized - UNAUTHENTICATED.
@ -98,9 +99,7 @@ var (
}
)
// Records the states during HPACK decoding. Must be reset once the
// decoding of the entire headers are finished.
type decodeState struct {
type parsedHeaderData struct {
encoding string
// statusGen caches the stream status received from the trailer the server
// sent. Client side only. Do not access directly. After all trailers are
@ -120,8 +119,38 @@ type decodeState struct {
statsTags []byte
statsTrace []byte
contentSubtype string
// isGRPC field indicates whether the peer is speaking gRPC (otherwise HTTP).
//
// We are in gRPC mode (peer speaking gRPC) if:
// * We are client side and have already received a HEADER frame that indicates gRPC peer.
// * The header contains valid a content-type, i.e. a string starts with "application/grpc"
// And we should handle error specific to gRPC.
//
// Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we
// are in HTTP fallback mode, and should handle error specific to HTTP.
isGRPC bool
grpcErr error
httpErr error
contentTypeErr string
}
// decodeState configures decoding criteria and records the decoded data.
type decodeState struct {
// whether decoding on server side or not
serverSide bool
// ignoreContentType indicates whether when processing the HEADERS frame, ignoring checking the
// content-type is grpc or not.
//
// Trailers (after headers) should not have a content-type. And thus we will ignore checking the
// content-type.
//
// For server, this field is always false.
ignoreContentType bool
// Records the states during HPACK decoding. It will be filled with info parsed from HTTP HEADERS
// frame once decodeHeader function has been invoked and returned.
data parsedHeaderData
}
// isReservedHeader checks whether hdr belongs to HTTP2 headers
@ -202,11 +231,11 @@ func contentType(contentSubtype string) string {
}
func (d *decodeState) status() *status.Status {
if d.statusGen == nil {
if d.data.statusGen == nil {
// No status-details were provided; generate status using code/msg.
d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
d.data.statusGen = status.New(codes.Code(int32(*(d.data.rawStatusCode))), d.data.rawStatusMsg)
}
return d.statusGen
return d.data.statusGen
}
const binHdrSuffix = "-bin"
@ -244,113 +273,146 @@ func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) error {
if frame.Truncated {
return status.Error(codes.Internal, "peer header list size exceeded limit")
}
for _, hf := range frame.Fields {
if err := d.processHeaderField(hf); err != nil {
return err
d.processHeaderField(hf)
}
if d.data.isGRPC {
if d.data.grpcErr != nil {
return d.data.grpcErr
}
if d.serverSide {
return nil
}
if d.data.rawStatusCode == nil && d.data.statusGen == nil {
// gRPC status doesn't exist.
// Set rawStatusCode to be unknown and return nil error.
// So that, if the stream has ended this Unknown status
// will be propagated to the user.
// Otherwise, it will be ignored. In which case, status from
// a later trailer, that has StreamEnded flag set, is propagated.
code := int(codes.Unknown)
d.data.rawStatusCode = &code
}
}
if d.serverSide {
return nil
}
// If grpc status exists, no need to check further.
if d.rawStatusCode != nil || d.statusGen != nil {
return nil
// HTTP fallback mode
if d.data.httpErr != nil {
return d.data.httpErr
}
// If grpc status doesn't exist and http status doesn't exist,
// then it's a malformed header.
if d.httpStatus == nil {
return status.Error(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)")
}
var (
code = codes.Internal // when header does not include HTTP status, return INTERNAL
ok bool
)
if *(d.httpStatus) != http.StatusOK {
code, ok := httpStatusConvTab[*(d.httpStatus)]
if d.data.httpStatus != nil {
code, ok = HTTPStatusConvTab[*(d.data.httpStatus)]
if !ok {
code = codes.Unknown
}
return status.Error(code, http.StatusText(*(d.httpStatus)))
}
// gRPC status doesn't exist and http status is OK.
// Set rawStatusCode to be unknown and return nil error.
// So that, if the stream has ended this Unknown status
// will be propagated to the user.
// Otherwise, it will be ignored. In which case, status from
// a later trailer, that has StreamEnded flag set, is propagated.
code := int(codes.Unknown)
d.rawStatusCode = &code
return nil
return status.Error(code, d.constructHTTPErrMsg())
}
// constructErrMsg constructs error message to be returned in HTTP fallback mode.
// Format: HTTP status code and its corresponding message + content-type error message.
func (d *decodeState) constructHTTPErrMsg() string {
var errMsgs []string
if d.data.httpStatus == nil {
errMsgs = append(errMsgs, "malformed header: missing HTTP status")
} else {
errMsgs = append(errMsgs, fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*(d.data.httpStatus)), *d.data.httpStatus))
}
if d.data.contentTypeErr == "" {
errMsgs = append(errMsgs, "transport: missing content-type field")
} else {
errMsgs = append(errMsgs, d.data.contentTypeErr)
}
return strings.Join(errMsgs, "; ")
}
func (d *decodeState) addMetadata(k, v string) {
if d.mdata == nil {
d.mdata = make(map[string][]string)
if d.data.mdata == nil {
d.data.mdata = make(map[string][]string)
}
d.mdata[k] = append(d.mdata[k], v)
d.data.mdata[k] = append(d.data.mdata[k], v)
}
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
func (d *decodeState) processHeaderField(f hpack.HeaderField) {
switch f.Name {
case "content-type":
contentSubtype, validContentType := contentSubtype(f.Value)
if !validContentType {
return status.Errorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value)
d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value)
return
}
d.contentSubtype = contentSubtype
d.data.contentSubtype = contentSubtype
// TODO: do we want to propagate the whole content-type in the metadata,
// or come up with a way to just propagate the content-subtype if it was set?
// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
// in the metadata?
d.addMetadata(f.Name, f.Value)
d.data.isGRPC = true
case "grpc-encoding":
d.encoding = f.Value
d.data.encoding = f.Value
case "grpc-status":
code, err := strconv.Atoi(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err)
return
}
d.rawStatusCode = &code
d.data.rawStatusCode = &code
case "grpc-message":
d.rawStatusMsg = decodeGrpcMessage(f.Value)
d.data.rawStatusMsg = decodeGrpcMessage(f.Value)
case "grpc-status-details-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
return
}
s := &spb.Status{}
if err := proto.Unmarshal(v, s); err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
return
}
d.statusGen = status.FromProto(s)
d.data.statusGen = status.FromProto(s)
case "grpc-timeout":
d.timeoutSet = true
d.data.timeoutSet = true
var err error
if d.timeout, err = decodeTimeout(f.Value); err != nil {
return status.Errorf(codes.Internal, "transport: malformed time-out: %v", err)
if d.data.timeout, err = decodeTimeout(f.Value); err != nil {
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err)
}
case ":path":
d.method = f.Value
d.data.method = f.Value
case ":status":
code, err := strconv.Atoi(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed http-status: %v", err)
d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err)
return
}
d.httpStatus = &code
d.data.httpStatus = &code
case "grpc-tags-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
return
}
d.statsTags = v
d.data.statsTags = v
d.addMetadata(f.Name, string(v))
case "grpc-trace-bin":
v, err := decodeBinHeader(f.Value)
if err != nil {
return status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
return
}
d.statsTrace = v
d.data.statsTrace = v
d.addMetadata(f.Name, string(v))
default:
if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) {
@ -359,11 +421,10 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
v, err := decodeMetadataHeader(f.Name, f.Value)
if err != nil {
errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err)
return nil
return
}
d.addMetadata(f.Name, v)
}
return nil
}
type timeoutUnit uint8

View File

@ -327,8 +327,7 @@ func (s *Stream) TrailersOnly() (bool, error) {
if err != nil {
return false, err
}
// if !headerDone, some other connection error occurred.
return s.noHeaders && atomic.LoadUint32(&s.headerDone) == 1, nil
return s.noHeaders, nil
}
// Trailer returns the cached trailer metedata. Note that if it is not called

View File

@ -19,7 +19,6 @@
package transport
import (
"bufio"
"bytes"
"context"
"encoding/binary"
@ -28,7 +27,6 @@ import (
"io"
"math"
"net"
"net/http"
"reflect"
"runtime"
"strconv"
@ -1943,167 +1941,6 @@ func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
}
}
// A function of type writeHeaders writes out
// http status with the given stream ID using the given framer.
type writeHeaders func(*http2.Framer, uint32, int) error
func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error {
var buf bytes.Buffer
henc := hpack.NewEncoder(&buf)
henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)})
return framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: sid,
BlockFragment: buf.Bytes(),
EndStream: true,
EndHeaders: true,
})
}
func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error {
var buf bytes.Buffer
henc := hpack.NewEncoder(&buf)
henc.WriteField(hpack.HeaderField{
Name: ":status",
Value: fmt.Sprint(http.StatusOK),
})
if err := framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: sid,
BlockFragment: buf.Bytes(),
EndHeaders: true,
}); err != nil {
return err
}
buf.Reset()
henc.WriteField(hpack.HeaderField{
Name: ":status",
Value: fmt.Sprint(httpStatus),
})
return framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: sid,
BlockFragment: buf.Bytes(),
EndStream: true,
EndHeaders: true,
})
}
type httpServer struct {
httpStatus int
wh writeHeaders
}
func (s *httpServer) start(t *testing.T, lis net.Listener) {
// Launch an HTTP server to send back header with httpStatus.
go func() {
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error accepting connection: %v", err)
return
}
defer conn.Close()
// Read preface sent by client.
if _, err = io.ReadFull(conn, make([]byte, len(http2.ClientPreface))); err != nil {
t.Errorf("Error at server-side while reading preface from client. Err: %v", err)
return
}
reader := bufio.NewReaderSize(conn, defaultWriteBufSize)
writer := bufio.NewWriterSize(conn, defaultReadBufSize)
framer := http2.NewFramer(writer, reader)
if err = framer.WriteSettingsAck(); err != nil {
t.Errorf("Error at server-side while sending Settings ack. Err: %v", err)
return
}
var sid uint32
// Read frames until a header is received.
for {
frame, err := framer.ReadFrame()
if err != nil {
t.Errorf("Error at server-side while reading frame. Err: %v", err)
return
}
if hframe, ok := frame.(*http2.HeadersFrame); ok {
sid = hframe.Header().StreamID
break
}
}
if err = s.wh(framer, sid, s.httpStatus); err != nil {
t.Errorf("Error at server-side while writing headers. Err: %v", err)
return
}
writer.Flush()
}()
}
func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (*Stream, func()) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen. Err: %v", err)
}
server := &httpServer{
httpStatus: httpStatus,
wh: wh,
}
server.start(t, lis)
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
defer cancel()
client, err := newHTTP2Client(connectCtx, context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, func() {}, func(GoAwayReason) {}, func() {})
if err != nil {
lis.Close()
t.Fatalf("Error creating client. Err: %v", err)
}
stream, err := client.NewStream(context.Background(), &CallHdr{Method: "bogus/method"})
if err != nil {
client.Close()
lis.Close()
t.Fatalf("Error creating stream at client-side. Err: %v", err)
}
return stream, func() {
client.Close()
lis.Close()
}
}
func TestHTTPToGRPCStatusMapping(t *testing.T) {
for k := range httpStatusConvTab {
testHTTPToGRPCStatusMapping(t, k, writeOneHeader)
}
}
func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) {
stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
defer cleanUp()
want := httpStatusConvTab[httpStatus]
buf := make([]byte, 8)
_, err := stream.Read(buf)
if err == nil {
t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
}
serr, ok := status.FromError(err)
if !ok {
t.Fatalf("err.(Type) = %T, want status error", err)
}
if want != serr.Code() {
t.Fatalf("Want error code: %v, got: %v", want, serr.Code())
}
}
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
defer cleanUp()
buf := make([]byte, 8)
_, err := stream.Read(buf)
if err != io.EOF {
t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
}
want := codes.Unknown
if stream.status.Code() != want {
t.Fatalf("Status code of stream: %v, want: %v", stream.status.Code(), want)
}
}
func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders)
}
// If any error occurs on a call to Stream.Read, future calls
// should continue to return that same error.
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {

View File

@ -22,6 +22,7 @@
package test
import (
"bufio"
"bytes"
"context"
"crypto/tls"
@ -45,6 +46,7 @@ import (
"github.com/golang/protobuf/proto"
anypb "github.com/golang/protobuf/ptypes/any"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/roundrobin"
@ -62,6 +64,7 @@ import (
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
@ -6892,7 +6895,7 @@ func testClientMaxHeaderListSizeServerIntentionalViolation(t *testing.T, e env)
time.Sleep(100 * time.Millisecond)
rcw.writeHeaders(http2.HeadersFrameParam{
StreamID: tc.getCurrentStreamID(),
BlockFragment: rcw.encodeHeader("oversize", strings.Join(val, "")),
BlockFragment: rcw.encodeRawHeader("oversize", strings.Join(val, "")),
EndStream: false,
EndHeaders: true,
})
@ -7125,3 +7128,221 @@ func (s) TestRPCWaitsForResolver(t *testing.T) {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err)
}
}
func (s) TestHTTPHeaderFrameErrorHandlingHTTPMode(t *testing.T) {
// Non-gRPC content-type fallback path.
for httpCode := range transport.HTTPStatusConvTab {
doHTTPHeaderTest(t, transport.HTTPStatusConvTab[int(httpCode)], []string{
":status", fmt.Sprintf("%d", httpCode),
"content-type", "text/html", // non-gRPC content type to switch to HTTP mode.
"grpc-status", "1", // Make up a gRPC status error
"grpc-status-details-bin", "???", // Make up a gRPC field parsing error
})
}
// Missing content-type fallback path.
for httpCode := range transport.HTTPStatusConvTab {
doHTTPHeaderTest(t, transport.HTTPStatusConvTab[int(httpCode)], []string{
":status", fmt.Sprintf("%d", httpCode),
// Omitting content type to switch to HTTP mode.
"grpc-status", "1", // Make up a gRPC status error
"grpc-status-details-bin", "???", // Make up a gRPC field parsing error
})
}
// Malformed HTTP status when fallback.
doHTTPHeaderTest(t, codes.Internal, []string{
":status", "abc",
// Omitting content type to switch to HTTP mode.
"grpc-status", "1", // Make up a gRPC status error
"grpc-status-details-bin", "???", // Make up a gRPC field parsing error
})
}
// Testing erroneous ReponseHeader or Trailers-only (delivered in the first HEADERS frame).
func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) {
for _, test := range []struct {
header []string
errCode codes.Code
}{
{
// missing gRPC status.
header: []string{
":status", "403",
"content-type", "application/grpc",
},
errCode: codes.Unknown,
},
{
// malformed grpc-status.
header: []string{
":status", "502",
"content-type", "application/grpc",
"grpc-status", "abc",
},
errCode: codes.Internal,
},
{
// Malformed grpc-tags-bin field.
header: []string{
":status", "502",
"content-type", "application/grpc",
"grpc-status", "0",
"grpc-tags-bin", "???",
},
errCode: codes.Internal,
},
{
// gRPC status error.
header: []string{
":status", "502",
"content-type", "application/grpc",
"grpc-status", "3",
},
errCode: codes.InvalidArgument,
},
} {
doHTTPHeaderTest(t, test.errCode, test.header)
}
}
// Testing non-Trailers-only Trailers (delievered in second HEADERS frame)
func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) {
for _, test := range []struct {
responseHeader []string
trailer []string
errCode codes.Code
}{
{
responseHeader: []string{
":status", "200",
"content-type", "application/grpc",
},
trailer: []string{
// trailer missing grpc-status
":status", "502",
},
errCode: codes.Unknown,
},
{
responseHeader: []string{
":status", "404",
"content-type", "application/grpc",
},
trailer: []string{
// malformed grpc-status-details-bin field
"grpc-status", "0",
"grpc-status-details-bin", "????",
},
errCode: codes.Internal,
},
} {
doHTTPHeaderTest(t, test.errCode, test.responseHeader, test.trailer)
}
}
func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) {
header := []string{
":status", "200",
"content-type", "application/grpc",
}
doHTTPHeaderTest(t, codes.Internal, header, header, header)
}
type httpServer struct {
headerFields [][]string
}
func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error {
if len(headerFields)%2 == 1 {
panic("odd number of kv args")
}
var buf bytes.Buffer
henc := hpack.NewEncoder(&buf)
for len(headerFields) > 0 {
k, v := headerFields[0], headerFields[1]
headerFields = headerFields[2:]
henc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
return framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: sid,
BlockFragment: buf.Bytes(),
EndStream: endStream,
EndHeaders: true,
})
}
func (s *httpServer) start(t *testing.T, lis net.Listener) {
// Launch an HTTP server to send back header.
go func() {
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error accepting connection: %v", err)
return
}
defer conn.Close()
// Read preface sent by client.
if _, err = io.ReadFull(conn, make([]byte, len(http2.ClientPreface))); err != nil {
t.Errorf("Error at server-side while reading preface from client. Err: %v", err)
return
}
reader := bufio.NewReader(conn)
writer := bufio.NewWriter(conn)
framer := http2.NewFramer(writer, reader)
if err = framer.WriteSettingsAck(); err != nil {
t.Errorf("Error at server-side while sending Settings ack. Err: %v", err)
return
}
writer.Flush() // necessary since client is expecting preface before declaring connection fully setup.
var sid uint32
// Read frames until a header is received.
for {
frame, err := framer.ReadFrame()
if err != nil {
t.Errorf("Error at server-side while reading frame. Err: %v", err)
return
}
if hframe, ok := frame.(*http2.HeadersFrame); ok {
sid = hframe.Header().StreamID
break
}
}
for i, headers := range s.headerFields {
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
t.Errorf("Error at server-side while writing headers. Err: %v", err)
return
}
writer.Flush()
}
}()
}
func doHTTPHeaderTest(t *testing.T, errCode codes.Code, headerFields ...[]string) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen. Err: %v", err)
}
defer lis.Close()
server := &httpServer{
headerFields: headerFields,
}
server.start(t, lis)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("failed to dial due to err: %v", err)
}
defer cc.Close()
client := testpb.NewTestServiceClient(cc)
stream, err := client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("error creating stream due to err: %v", err)
}
if _, err := stream.Recv(); err == nil || status.Code(err) != errCode {
t.Fatalf("stream.Recv() = _, %v, want error code: %v", err, errCode)
}
}

View File

@ -227,6 +227,47 @@ func (rcw *rawConnWrapper) encodeHeaderField(k, v string) error {
return nil
}
// encodeRawHeader is for usage on both client and server side to construct header based on the input
// key, value pairs.
func (rcw *rawConnWrapper) encodeRawHeader(headers ...string) []byte {
if len(headers)%2 == 1 {
panic("odd number of kv args")
}
rcw.headerBuf.Reset()
pseudoCount := map[string]int{}
var keys []string
vals := map[string][]string{}
for len(headers) > 0 {
k, v := headers[0], headers[1]
headers = headers[2:]
if _, ok := vals[k]; !ok {
keys = append(keys, k)
}
if strings.HasPrefix(k, ":") {
pseudoCount[k]++
if pseudoCount[k] == 1 {
vals[k] = []string{v}
} else {
// Allows testing of invalid headers w/ dup pseudo fields.
vals[k] = append(vals[k], v)
}
} else {
vals[k] = append(vals[k], v)
}
}
for _, k := range keys {
for _, v := range vals[k] {
rcw.encodeHeaderField(k, v)
}
}
return rcw.headerBuf.Bytes()
}
// encodeHeader is for usage on client side to write request header.
//
// encodeHeader encodes headers and returns their HPACK bytes. headers
// must contain an even number of key/value pairs. There may be
// multiple pairs for keys (e.g. "cookie"). The :method, :path, and
@ -288,6 +329,7 @@ func (rcw *rawConnWrapper) encodeHeader(headers ...string) []byte {
return rcw.headerBuf.Bytes()
}
// writeHeadersGRPC is for usage on client side to write request header.
func (rcw *rawConnWrapper) writeHeadersGRPC(streamID uint32, path string) {
rcw.writeHeaders(http2.HeadersFrameParam{
StreamID: streamID,