mirror of https://github.com/grpc/grpc-go.git
client: handle HTTP header parsing error correctly (#2599)
This commit is contained in:
parent
45890ffd9e
commit
79c9bc6794
|
@ -1140,15 +1140,30 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
endStream := frame.StreamEnded()
|
||||
atomic.StoreUint32(&s.bytesReceived, 1)
|
||||
var state decodeState
|
||||
if err := state.decodeHeader(frame); err != nil {
|
||||
t.closeStream(s, err, true, http2.ErrCodeProtocol, status.New(codes.Internal, err.Error()), nil, false)
|
||||
// Something wrong. Stops reading even when there is remaining.
|
||||
initialHeader := atomic.SwapUint32(&s.headerDone, 1) == 0
|
||||
|
||||
if !initialHeader && !endStream {
|
||||
// As specified by RFC 7540, a HEADERS frame (and associated CONTINUATION frames) can only appear
|
||||
// at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set.
|
||||
st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream")
|
||||
t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false)
|
||||
return
|
||||
}
|
||||
|
||||
state := &decodeState{
|
||||
serverSide: false,
|
||||
ignoreContentType: !initialHeader,
|
||||
}
|
||||
// Initialize isGRPC value to be !initialHeader, since if a gRPC ResponseHeader has been received
|
||||
// which indicates peer speaking gRPC, we are in gRPC mode.
|
||||
state.data.isGRPC = !initialHeader
|
||||
if err := state.decodeHeader(frame); err != nil {
|
||||
t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream)
|
||||
return
|
||||
}
|
||||
|
||||
endStream := frame.StreamEnded()
|
||||
var isHeader bool
|
||||
defer func() {
|
||||
if t.statsHandler != nil {
|
||||
|
@ -1167,29 +1182,30 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
|||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// If headers haven't been received yet.
|
||||
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
|
||||
if initialHeader {
|
||||
if !endStream {
|
||||
// Headers frame is not actually a trailers-only frame.
|
||||
// Headers frame is ResponseHeader.
|
||||
isHeader = true
|
||||
// These values can be set without any synchronization because
|
||||
// stream goroutine will read it only after seeing a closed
|
||||
// headerChan which we'll close after setting this.
|
||||
s.recvCompress = state.encoding
|
||||
if len(state.mdata) > 0 {
|
||||
s.header = state.mdata
|
||||
s.recvCompress = state.data.encoding
|
||||
if len(state.data.mdata) > 0 {
|
||||
s.header = state.data.mdata
|
||||
}
|
||||
} else {
|
||||
s.noHeaders = true
|
||||
close(s.headerChan)
|
||||
return
|
||||
}
|
||||
// Headers frame is Trailers-only.
|
||||
s.noHeaders = true
|
||||
close(s.headerChan)
|
||||
}
|
||||
if !endStream {
|
||||
return
|
||||
}
|
||||
|
||||
// if client received END_STREAM from server while stream was still active, send RST_STREAM
|
||||
rst := s.getState() == streamActive
|
||||
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.mdata, true)
|
||||
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true)
|
||||
}
|
||||
|
||||
// reader runs as a separate goroutine in charge of reading data from network
|
||||
|
|
|
@ -286,7 +286,10 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
|
|||
// operateHeader takes action on the decoded headers.
|
||||
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) {
|
||||
streamID := frame.Header().StreamID
|
||||
state := decodeState{serverSide: true}
|
||||
state := &decodeState{
|
||||
serverSide: true,
|
||||
ignoreContentType: false,
|
||||
}
|
||||
if err := state.decodeHeader(frame); err != nil {
|
||||
if se, ok := status.FromError(err); ok {
|
||||
t.controlBuf.put(&cleanupStream{
|
||||
|
@ -305,16 +308,16 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||
st: t,
|
||||
buf: buf,
|
||||
fc: &inFlow{limit: uint32(t.initialWindowSize)},
|
||||
recvCompress: state.encoding,
|
||||
method: state.method,
|
||||
contentSubtype: state.contentSubtype,
|
||||
recvCompress: state.data.encoding,
|
||||
method: state.data.method,
|
||||
contentSubtype: state.data.contentSubtype,
|
||||
}
|
||||
if frame.StreamEnded() {
|
||||
// s is just created by the caller. No lock needed.
|
||||
s.state = streamReadDone
|
||||
}
|
||||
if state.timeoutSet {
|
||||
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout)
|
||||
if state.data.timeoutSet {
|
||||
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout)
|
||||
} else {
|
||||
s.ctx, s.cancel = context.WithCancel(t.ctx)
|
||||
}
|
||||
|
@ -327,19 +330,19 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||
}
|
||||
s.ctx = peer.NewContext(s.ctx, pr)
|
||||
// Attach the received metadata to the context.
|
||||
if len(state.mdata) > 0 {
|
||||
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
|
||||
if len(state.data.mdata) > 0 {
|
||||
s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata)
|
||||
}
|
||||
if state.statsTags != nil {
|
||||
s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags)
|
||||
if state.data.statsTags != nil {
|
||||
s.ctx = stats.SetIncomingTags(s.ctx, state.data.statsTags)
|
||||
}
|
||||
if state.statsTrace != nil {
|
||||
s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace)
|
||||
if state.data.statsTrace != nil {
|
||||
s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace)
|
||||
}
|
||||
if t.inTapHandle != nil {
|
||||
var err error
|
||||
info := &tap.Info{
|
||||
FullMethodName: state.method,
|
||||
FullMethodName: state.data.method,
|
||||
}
|
||||
s.ctx, err = t.inTapHandle(s.ctx, info)
|
||||
if err != nil {
|
||||
|
|
|
@ -78,7 +78,8 @@ var (
|
|||
codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm,
|
||||
codes.PermissionDenied: http2.ErrCodeInadequateSecurity,
|
||||
}
|
||||
httpStatusConvTab = map[int]codes.Code{
|
||||
// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
|
||||
HTTPStatusConvTab = map[int]codes.Code{
|
||||
// 400 Bad Request - INTERNAL.
|
||||
http.StatusBadRequest: codes.Internal,
|
||||
// 401 Unauthorized - UNAUTHENTICATED.
|
||||
|
@ -98,9 +99,7 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
// Records the states during HPACK decoding. Must be reset once the
|
||||
// decoding of the entire headers are finished.
|
||||
type decodeState struct {
|
||||
type parsedHeaderData struct {
|
||||
encoding string
|
||||
// statusGen caches the stream status received from the trailer the server
|
||||
// sent. Client side only. Do not access directly. After all trailers are
|
||||
|
@ -120,8 +119,38 @@ type decodeState struct {
|
|||
statsTags []byte
|
||||
statsTrace []byte
|
||||
contentSubtype string
|
||||
|
||||
// isGRPC field indicates whether the peer is speaking gRPC (otherwise HTTP).
|
||||
//
|
||||
// We are in gRPC mode (peer speaking gRPC) if:
|
||||
// * We are client side and have already received a HEADER frame that indicates gRPC peer.
|
||||
// * The header contains valid a content-type, i.e. a string starts with "application/grpc"
|
||||
// And we should handle error specific to gRPC.
|
||||
//
|
||||
// Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we
|
||||
// are in HTTP fallback mode, and should handle error specific to HTTP.
|
||||
isGRPC bool
|
||||
grpcErr error
|
||||
httpErr error
|
||||
contentTypeErr string
|
||||
}
|
||||
|
||||
// decodeState configures decoding criteria and records the decoded data.
|
||||
type decodeState struct {
|
||||
// whether decoding on server side or not
|
||||
serverSide bool
|
||||
// ignoreContentType indicates whether when processing the HEADERS frame, ignoring checking the
|
||||
// content-type is grpc or not.
|
||||
//
|
||||
// Trailers (after headers) should not have a content-type. And thus we will ignore checking the
|
||||
// content-type.
|
||||
//
|
||||
// For server, this field is always false.
|
||||
ignoreContentType bool
|
||||
|
||||
// Records the states during HPACK decoding. It will be filled with info parsed from HTTP HEADERS
|
||||
// frame once decodeHeader function has been invoked and returned.
|
||||
data parsedHeaderData
|
||||
}
|
||||
|
||||
// isReservedHeader checks whether hdr belongs to HTTP2 headers
|
||||
|
@ -202,11 +231,11 @@ func contentType(contentSubtype string) string {
|
|||
}
|
||||
|
||||
func (d *decodeState) status() *status.Status {
|
||||
if d.statusGen == nil {
|
||||
if d.data.statusGen == nil {
|
||||
// No status-details were provided; generate status using code/msg.
|
||||
d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg)
|
||||
d.data.statusGen = status.New(codes.Code(int32(*(d.data.rawStatusCode))), d.data.rawStatusMsg)
|
||||
}
|
||||
return d.statusGen
|
||||
return d.data.statusGen
|
||||
}
|
||||
|
||||
const binHdrSuffix = "-bin"
|
||||
|
@ -244,113 +273,146 @@ func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) error {
|
|||
if frame.Truncated {
|
||||
return status.Error(codes.Internal, "peer header list size exceeded limit")
|
||||
}
|
||||
|
||||
for _, hf := range frame.Fields {
|
||||
if err := d.processHeaderField(hf); err != nil {
|
||||
return err
|
||||
d.processHeaderField(hf)
|
||||
}
|
||||
|
||||
if d.data.isGRPC {
|
||||
if d.data.grpcErr != nil {
|
||||
return d.data.grpcErr
|
||||
}
|
||||
if d.serverSide {
|
||||
return nil
|
||||
}
|
||||
if d.data.rawStatusCode == nil && d.data.statusGen == nil {
|
||||
// gRPC status doesn't exist.
|
||||
// Set rawStatusCode to be unknown and return nil error.
|
||||
// So that, if the stream has ended this Unknown status
|
||||
// will be propagated to the user.
|
||||
// Otherwise, it will be ignored. In which case, status from
|
||||
// a later trailer, that has StreamEnded flag set, is propagated.
|
||||
code := int(codes.Unknown)
|
||||
d.data.rawStatusCode = &code
|
||||
}
|
||||
}
|
||||
|
||||
if d.serverSide {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If grpc status exists, no need to check further.
|
||||
if d.rawStatusCode != nil || d.statusGen != nil {
|
||||
return nil
|
||||
// HTTP fallback mode
|
||||
if d.data.httpErr != nil {
|
||||
return d.data.httpErr
|
||||
}
|
||||
|
||||
// If grpc status doesn't exist and http status doesn't exist,
|
||||
// then it's a malformed header.
|
||||
if d.httpStatus == nil {
|
||||
return status.Error(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)")
|
||||
}
|
||||
var (
|
||||
code = codes.Internal // when header does not include HTTP status, return INTERNAL
|
||||
ok bool
|
||||
)
|
||||
|
||||
if *(d.httpStatus) != http.StatusOK {
|
||||
code, ok := httpStatusConvTab[*(d.httpStatus)]
|
||||
if d.data.httpStatus != nil {
|
||||
code, ok = HTTPStatusConvTab[*(d.data.httpStatus)]
|
||||
if !ok {
|
||||
code = codes.Unknown
|
||||
}
|
||||
return status.Error(code, http.StatusText(*(d.httpStatus)))
|
||||
}
|
||||
|
||||
// gRPC status doesn't exist and http status is OK.
|
||||
// Set rawStatusCode to be unknown and return nil error.
|
||||
// So that, if the stream has ended this Unknown status
|
||||
// will be propagated to the user.
|
||||
// Otherwise, it will be ignored. In which case, status from
|
||||
// a later trailer, that has StreamEnded flag set, is propagated.
|
||||
code := int(codes.Unknown)
|
||||
d.rawStatusCode = &code
|
||||
return nil
|
||||
return status.Error(code, d.constructHTTPErrMsg())
|
||||
}
|
||||
|
||||
// constructErrMsg constructs error message to be returned in HTTP fallback mode.
|
||||
// Format: HTTP status code and its corresponding message + content-type error message.
|
||||
func (d *decodeState) constructHTTPErrMsg() string {
|
||||
var errMsgs []string
|
||||
|
||||
if d.data.httpStatus == nil {
|
||||
errMsgs = append(errMsgs, "malformed header: missing HTTP status")
|
||||
} else {
|
||||
errMsgs = append(errMsgs, fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*(d.data.httpStatus)), *d.data.httpStatus))
|
||||
}
|
||||
|
||||
if d.data.contentTypeErr == "" {
|
||||
errMsgs = append(errMsgs, "transport: missing content-type field")
|
||||
} else {
|
||||
errMsgs = append(errMsgs, d.data.contentTypeErr)
|
||||
}
|
||||
|
||||
return strings.Join(errMsgs, "; ")
|
||||
}
|
||||
|
||||
func (d *decodeState) addMetadata(k, v string) {
|
||||
if d.mdata == nil {
|
||||
d.mdata = make(map[string][]string)
|
||||
if d.data.mdata == nil {
|
||||
d.data.mdata = make(map[string][]string)
|
||||
}
|
||||
d.mdata[k] = append(d.mdata[k], v)
|
||||
d.data.mdata[k] = append(d.data.mdata[k], v)
|
||||
}
|
||||
|
||||
func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
|
||||
func (d *decodeState) processHeaderField(f hpack.HeaderField) {
|
||||
switch f.Name {
|
||||
case "content-type":
|
||||
contentSubtype, validContentType := contentSubtype(f.Value)
|
||||
if !validContentType {
|
||||
return status.Errorf(codes.Internal, "transport: received the unexpected content-type %q", f.Value)
|
||||
d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value)
|
||||
return
|
||||
}
|
||||
d.contentSubtype = contentSubtype
|
||||
d.data.contentSubtype = contentSubtype
|
||||
// TODO: do we want to propagate the whole content-type in the metadata,
|
||||
// or come up with a way to just propagate the content-subtype if it was set?
|
||||
// ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"}
|
||||
// in the metadata?
|
||||
d.addMetadata(f.Name, f.Value)
|
||||
d.data.isGRPC = true
|
||||
case "grpc-encoding":
|
||||
d.encoding = f.Value
|
||||
d.data.encoding = f.Value
|
||||
case "grpc-status":
|
||||
code, err := strconv.Atoi(f.Value)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err)
|
||||
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err)
|
||||
return
|
||||
}
|
||||
d.rawStatusCode = &code
|
||||
d.data.rawStatusCode = &code
|
||||
case "grpc-message":
|
||||
d.rawStatusMsg = decodeGrpcMessage(f.Value)
|
||||
d.data.rawStatusMsg = decodeGrpcMessage(f.Value)
|
||||
case "grpc-status-details-bin":
|
||||
v, err := decodeBinHeader(f.Value)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
|
||||
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
|
||||
return
|
||||
}
|
||||
s := &spb.Status{}
|
||||
if err := proto.Unmarshal(v, s); err != nil {
|
||||
return status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
|
||||
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err)
|
||||
return
|
||||
}
|
||||
d.statusGen = status.FromProto(s)
|
||||
d.data.statusGen = status.FromProto(s)
|
||||
case "grpc-timeout":
|
||||
d.timeoutSet = true
|
||||
d.data.timeoutSet = true
|
||||
var err error
|
||||
if d.timeout, err = decodeTimeout(f.Value); err != nil {
|
||||
return status.Errorf(codes.Internal, "transport: malformed time-out: %v", err)
|
||||
if d.data.timeout, err = decodeTimeout(f.Value); err != nil {
|
||||
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err)
|
||||
}
|
||||
case ":path":
|
||||
d.method = f.Value
|
||||
d.data.method = f.Value
|
||||
case ":status":
|
||||
code, err := strconv.Atoi(f.Value)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "transport: malformed http-status: %v", err)
|
||||
d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err)
|
||||
return
|
||||
}
|
||||
d.httpStatus = &code
|
||||
d.data.httpStatus = &code
|
||||
case "grpc-tags-bin":
|
||||
v, err := decodeBinHeader(f.Value)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
|
||||
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err)
|
||||
return
|
||||
}
|
||||
d.statsTags = v
|
||||
d.data.statsTags = v
|
||||
d.addMetadata(f.Name, string(v))
|
||||
case "grpc-trace-bin":
|
||||
v, err := decodeBinHeader(f.Value)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
|
||||
d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err)
|
||||
return
|
||||
}
|
||||
d.statsTrace = v
|
||||
d.data.statsTrace = v
|
||||
d.addMetadata(f.Name, string(v))
|
||||
default:
|
||||
if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) {
|
||||
|
@ -359,11 +421,10 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error {
|
|||
v, err := decodeMetadataHeader(f.Name, f.Value)
|
||||
if err != nil {
|
||||
errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err)
|
||||
return nil
|
||||
return
|
||||
}
|
||||
d.addMetadata(f.Name, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type timeoutUnit uint8
|
||||
|
|
|
@ -327,8 +327,7 @@ func (s *Stream) TrailersOnly() (bool, error) {
|
|||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// if !headerDone, some other connection error occurred.
|
||||
return s.noHeaders && atomic.LoadUint32(&s.headerDone) == 1, nil
|
||||
return s.noHeaders, nil
|
||||
}
|
||||
|
||||
// Trailer returns the cached trailer metedata. Note that if it is not called
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
package transport
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
|
@ -28,7 +27,6 @@ import (
|
|||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
@ -1943,167 +1941,6 @@ func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
|
|||
}
|
||||
}
|
||||
|
||||
// A function of type writeHeaders writes out
|
||||
// http status with the given stream ID using the given framer.
|
||||
type writeHeaders func(*http2.Framer, uint32, int) error
|
||||
|
||||
func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error {
|
||||
var buf bytes.Buffer
|
||||
henc := hpack.NewEncoder(&buf)
|
||||
henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)})
|
||||
return framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: sid,
|
||||
BlockFragment: buf.Bytes(),
|
||||
EndStream: true,
|
||||
EndHeaders: true,
|
||||
})
|
||||
}
|
||||
|
||||
func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error {
|
||||
var buf bytes.Buffer
|
||||
henc := hpack.NewEncoder(&buf)
|
||||
henc.WriteField(hpack.HeaderField{
|
||||
Name: ":status",
|
||||
Value: fmt.Sprint(http.StatusOK),
|
||||
})
|
||||
if err := framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: sid,
|
||||
BlockFragment: buf.Bytes(),
|
||||
EndHeaders: true,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
buf.Reset()
|
||||
henc.WriteField(hpack.HeaderField{
|
||||
Name: ":status",
|
||||
Value: fmt.Sprint(httpStatus),
|
||||
})
|
||||
return framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: sid,
|
||||
BlockFragment: buf.Bytes(),
|
||||
EndStream: true,
|
||||
EndHeaders: true,
|
||||
})
|
||||
}
|
||||
|
||||
type httpServer struct {
|
||||
httpStatus int
|
||||
wh writeHeaders
|
||||
}
|
||||
|
||||
func (s *httpServer) start(t *testing.T, lis net.Listener) {
|
||||
// Launch an HTTP server to send back header with httpStatus.
|
||||
go func() {
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Error accepting connection: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
// Read preface sent by client.
|
||||
if _, err = io.ReadFull(conn, make([]byte, len(http2.ClientPreface))); err != nil {
|
||||
t.Errorf("Error at server-side while reading preface from client. Err: %v", err)
|
||||
return
|
||||
}
|
||||
reader := bufio.NewReaderSize(conn, defaultWriteBufSize)
|
||||
writer := bufio.NewWriterSize(conn, defaultReadBufSize)
|
||||
framer := http2.NewFramer(writer, reader)
|
||||
if err = framer.WriteSettingsAck(); err != nil {
|
||||
t.Errorf("Error at server-side while sending Settings ack. Err: %v", err)
|
||||
return
|
||||
}
|
||||
var sid uint32
|
||||
// Read frames until a header is received.
|
||||
for {
|
||||
frame, err := framer.ReadFrame()
|
||||
if err != nil {
|
||||
t.Errorf("Error at server-side while reading frame. Err: %v", err)
|
||||
return
|
||||
}
|
||||
if hframe, ok := frame.(*http2.HeadersFrame); ok {
|
||||
sid = hframe.Header().StreamID
|
||||
break
|
||||
}
|
||||
}
|
||||
if err = s.wh(framer, sid, s.httpStatus); err != nil {
|
||||
t.Errorf("Error at server-side while writing headers. Err: %v", err)
|
||||
return
|
||||
}
|
||||
writer.Flush()
|
||||
}()
|
||||
}
|
||||
|
||||
func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (*Stream, func()) {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen. Err: %v", err)
|
||||
}
|
||||
server := &httpServer{
|
||||
httpStatus: httpStatus,
|
||||
wh: wh,
|
||||
}
|
||||
server.start(t, lis)
|
||||
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
|
||||
defer cancel()
|
||||
client, err := newHTTP2Client(connectCtx, context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, func() {}, func(GoAwayReason) {}, func() {})
|
||||
if err != nil {
|
||||
lis.Close()
|
||||
t.Fatalf("Error creating client. Err: %v", err)
|
||||
}
|
||||
stream, err := client.NewStream(context.Background(), &CallHdr{Method: "bogus/method"})
|
||||
if err != nil {
|
||||
client.Close()
|
||||
lis.Close()
|
||||
t.Fatalf("Error creating stream at client-side. Err: %v", err)
|
||||
}
|
||||
return stream, func() {
|
||||
client.Close()
|
||||
lis.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPToGRPCStatusMapping(t *testing.T) {
|
||||
for k := range httpStatusConvTab {
|
||||
testHTTPToGRPCStatusMapping(t, k, writeOneHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) {
|
||||
stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
|
||||
defer cleanUp()
|
||||
want := httpStatusConvTab[httpStatus]
|
||||
buf := make([]byte, 8)
|
||||
_, err := stream.Read(buf)
|
||||
if err == nil {
|
||||
t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
|
||||
}
|
||||
serr, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("err.(Type) = %T, want status error", err)
|
||||
}
|
||||
if want != serr.Code() {
|
||||
t.Fatalf("Want error code: %v, got: %v", want, serr.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
|
||||
stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
|
||||
defer cleanUp()
|
||||
buf := make([]byte, 8)
|
||||
_, err := stream.Read(buf)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
|
||||
}
|
||||
want := codes.Unknown
|
||||
if stream.status.Code() != want {
|
||||
t.Fatalf("Status code of stream: %v, want: %v", stream.status.Code(), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
|
||||
testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders)
|
||||
}
|
||||
|
||||
// If any error occurs on a call to Stream.Read, future calls
|
||||
// should continue to return that same error.
|
||||
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
@ -45,6 +46,7 @@ import (
|
|||
"github.com/golang/protobuf/proto"
|
||||
anypb "github.com/golang/protobuf/ptypes/any"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
spb "google.golang.org/genproto/googleapis/rpc/status"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/balancer/roundrobin"
|
||||
|
@ -62,6 +64,7 @@ import (
|
|||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/internal/leakcheck"
|
||||
"google.golang.org/grpc/internal/testutils"
|
||||
"google.golang.org/grpc/internal/transport"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/resolver"
|
||||
|
@ -6892,7 +6895,7 @@ func testClientMaxHeaderListSizeServerIntentionalViolation(t *testing.T, e env)
|
|||
time.Sleep(100 * time.Millisecond)
|
||||
rcw.writeHeaders(http2.HeadersFrameParam{
|
||||
StreamID: tc.getCurrentStreamID(),
|
||||
BlockFragment: rcw.encodeHeader("oversize", strings.Join(val, "")),
|
||||
BlockFragment: rcw.encodeRawHeader("oversize", strings.Join(val, "")),
|
||||
EndStream: false,
|
||||
EndHeaders: true,
|
||||
})
|
||||
|
@ -7125,3 +7128,221 @@ func (s) TestRPCWaitsForResolver(t *testing.T) {
|
|||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestHTTPHeaderFrameErrorHandlingHTTPMode(t *testing.T) {
|
||||
// Non-gRPC content-type fallback path.
|
||||
for httpCode := range transport.HTTPStatusConvTab {
|
||||
doHTTPHeaderTest(t, transport.HTTPStatusConvTab[int(httpCode)], []string{
|
||||
":status", fmt.Sprintf("%d", httpCode),
|
||||
"content-type", "text/html", // non-gRPC content type to switch to HTTP mode.
|
||||
"grpc-status", "1", // Make up a gRPC status error
|
||||
"grpc-status-details-bin", "???", // Make up a gRPC field parsing error
|
||||
})
|
||||
}
|
||||
|
||||
// Missing content-type fallback path.
|
||||
for httpCode := range transport.HTTPStatusConvTab {
|
||||
doHTTPHeaderTest(t, transport.HTTPStatusConvTab[int(httpCode)], []string{
|
||||
":status", fmt.Sprintf("%d", httpCode),
|
||||
// Omitting content type to switch to HTTP mode.
|
||||
"grpc-status", "1", // Make up a gRPC status error
|
||||
"grpc-status-details-bin", "???", // Make up a gRPC field parsing error
|
||||
})
|
||||
}
|
||||
|
||||
// Malformed HTTP status when fallback.
|
||||
doHTTPHeaderTest(t, codes.Internal, []string{
|
||||
":status", "abc",
|
||||
// Omitting content type to switch to HTTP mode.
|
||||
"grpc-status", "1", // Make up a gRPC status error
|
||||
"grpc-status-details-bin", "???", // Make up a gRPC field parsing error
|
||||
})
|
||||
}
|
||||
|
||||
// Testing erroneous ReponseHeader or Trailers-only (delivered in the first HEADERS frame).
|
||||
func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
header []string
|
||||
errCode codes.Code
|
||||
}{
|
||||
{
|
||||
// missing gRPC status.
|
||||
header: []string{
|
||||
":status", "403",
|
||||
"content-type", "application/grpc",
|
||||
},
|
||||
errCode: codes.Unknown,
|
||||
},
|
||||
{
|
||||
// malformed grpc-status.
|
||||
header: []string{
|
||||
":status", "502",
|
||||
"content-type", "application/grpc",
|
||||
"grpc-status", "abc",
|
||||
},
|
||||
errCode: codes.Internal,
|
||||
},
|
||||
{
|
||||
// Malformed grpc-tags-bin field.
|
||||
header: []string{
|
||||
":status", "502",
|
||||
"content-type", "application/grpc",
|
||||
"grpc-status", "0",
|
||||
"grpc-tags-bin", "???",
|
||||
},
|
||||
errCode: codes.Internal,
|
||||
},
|
||||
{
|
||||
// gRPC status error.
|
||||
header: []string{
|
||||
":status", "502",
|
||||
"content-type", "application/grpc",
|
||||
"grpc-status", "3",
|
||||
},
|
||||
errCode: codes.InvalidArgument,
|
||||
},
|
||||
} {
|
||||
doHTTPHeaderTest(t, test.errCode, test.header)
|
||||
}
|
||||
}
|
||||
|
||||
// Testing non-Trailers-only Trailers (delievered in second HEADERS frame)
|
||||
func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
responseHeader []string
|
||||
trailer []string
|
||||
errCode codes.Code
|
||||
}{
|
||||
{
|
||||
responseHeader: []string{
|
||||
":status", "200",
|
||||
"content-type", "application/grpc",
|
||||
},
|
||||
trailer: []string{
|
||||
// trailer missing grpc-status
|
||||
":status", "502",
|
||||
},
|
||||
errCode: codes.Unknown,
|
||||
},
|
||||
{
|
||||
responseHeader: []string{
|
||||
":status", "404",
|
||||
"content-type", "application/grpc",
|
||||
},
|
||||
trailer: []string{
|
||||
// malformed grpc-status-details-bin field
|
||||
"grpc-status", "0",
|
||||
"grpc-status-details-bin", "????",
|
||||
},
|
||||
errCode: codes.Internal,
|
||||
},
|
||||
} {
|
||||
doHTTPHeaderTest(t, test.errCode, test.responseHeader, test.trailer)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) {
|
||||
header := []string{
|
||||
":status", "200",
|
||||
"content-type", "application/grpc",
|
||||
}
|
||||
doHTTPHeaderTest(t, codes.Internal, header, header, header)
|
||||
}
|
||||
|
||||
type httpServer struct {
|
||||
headerFields [][]string
|
||||
}
|
||||
|
||||
func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error {
|
||||
if len(headerFields)%2 == 1 {
|
||||
panic("odd number of kv args")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
henc := hpack.NewEncoder(&buf)
|
||||
for len(headerFields) > 0 {
|
||||
k, v := headerFields[0], headerFields[1]
|
||||
headerFields = headerFields[2:]
|
||||
henc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
}
|
||||
|
||||
return framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: sid,
|
||||
BlockFragment: buf.Bytes(),
|
||||
EndStream: endStream,
|
||||
EndHeaders: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *httpServer) start(t *testing.T, lis net.Listener) {
|
||||
// Launch an HTTP server to send back header.
|
||||
go func() {
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Error accepting connection: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
// Read preface sent by client.
|
||||
if _, err = io.ReadFull(conn, make([]byte, len(http2.ClientPreface))); err != nil {
|
||||
t.Errorf("Error at server-side while reading preface from client. Err: %v", err)
|
||||
return
|
||||
}
|
||||
reader := bufio.NewReader(conn)
|
||||
writer := bufio.NewWriter(conn)
|
||||
framer := http2.NewFramer(writer, reader)
|
||||
if err = framer.WriteSettingsAck(); err != nil {
|
||||
t.Errorf("Error at server-side while sending Settings ack. Err: %v", err)
|
||||
return
|
||||
}
|
||||
writer.Flush() // necessary since client is expecting preface before declaring connection fully setup.
|
||||
|
||||
var sid uint32
|
||||
// Read frames until a header is received.
|
||||
for {
|
||||
frame, err := framer.ReadFrame()
|
||||
if err != nil {
|
||||
t.Errorf("Error at server-side while reading frame. Err: %v", err)
|
||||
return
|
||||
}
|
||||
if hframe, ok := frame.(*http2.HeadersFrame); ok {
|
||||
sid = hframe.Header().StreamID
|
||||
break
|
||||
}
|
||||
}
|
||||
for i, headers := range s.headerFields {
|
||||
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
|
||||
t.Errorf("Error at server-side while writing headers. Err: %v", err)
|
||||
return
|
||||
}
|
||||
writer.Flush()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func doHTTPHeaderTest(t *testing.T, errCode codes.Code, headerFields ...[]string) {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen. Err: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
server := &httpServer{
|
||||
headerFields: headerFields,
|
||||
}
|
||||
server.start(t, lis)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
cc, err := grpc.DialContext(ctx, lis.Addr().String(), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial due to err: %v", err)
|
||||
}
|
||||
defer cc.Close()
|
||||
client := testpb.NewTestServiceClient(cc)
|
||||
stream, err := client.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating stream due to err: %v", err)
|
||||
}
|
||||
if _, err := stream.Recv(); err == nil || status.Code(err) != errCode {
|
||||
t.Fatalf("stream.Recv() = _, %v, want error code: %v", err, errCode)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -227,6 +227,47 @@ func (rcw *rawConnWrapper) encodeHeaderField(k, v string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// encodeRawHeader is for usage on both client and server side to construct header based on the input
|
||||
// key, value pairs.
|
||||
func (rcw *rawConnWrapper) encodeRawHeader(headers ...string) []byte {
|
||||
if len(headers)%2 == 1 {
|
||||
panic("odd number of kv args")
|
||||
}
|
||||
|
||||
rcw.headerBuf.Reset()
|
||||
|
||||
pseudoCount := map[string]int{}
|
||||
var keys []string
|
||||
vals := map[string][]string{}
|
||||
|
||||
for len(headers) > 0 {
|
||||
k, v := headers[0], headers[1]
|
||||
headers = headers[2:]
|
||||
if _, ok := vals[k]; !ok {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if strings.HasPrefix(k, ":") {
|
||||
pseudoCount[k]++
|
||||
if pseudoCount[k] == 1 {
|
||||
vals[k] = []string{v}
|
||||
} else {
|
||||
// Allows testing of invalid headers w/ dup pseudo fields.
|
||||
vals[k] = append(vals[k], v)
|
||||
}
|
||||
} else {
|
||||
vals[k] = append(vals[k], v)
|
||||
}
|
||||
}
|
||||
for _, k := range keys {
|
||||
for _, v := range vals[k] {
|
||||
rcw.encodeHeaderField(k, v)
|
||||
}
|
||||
}
|
||||
return rcw.headerBuf.Bytes()
|
||||
}
|
||||
|
||||
// encodeHeader is for usage on client side to write request header.
|
||||
//
|
||||
// encodeHeader encodes headers and returns their HPACK bytes. headers
|
||||
// must contain an even number of key/value pairs. There may be
|
||||
// multiple pairs for keys (e.g. "cookie"). The :method, :path, and
|
||||
|
@ -288,6 +329,7 @@ func (rcw *rawConnWrapper) encodeHeader(headers ...string) []byte {
|
|||
return rcw.headerBuf.Bytes()
|
||||
}
|
||||
|
||||
// writeHeadersGRPC is for usage on client side to write request header.
|
||||
func (rcw *rawConnWrapper) writeHeadersGRPC(streamID uint32, path string) {
|
||||
rcw.writeHeaders(http2.HeadersFrameParam{
|
||||
StreamID: streamID,
|
||||
|
|
Loading…
Reference in New Issue