mirror of https://github.com/grpc/grpc-go.git
Fix crashes where transports returned errors unhandled by the message parser.
The http.Handler-based transport body reader was returning error types not understood by the recvMsg parser. See #557 for some background and examples. Fix the http.Handler transport and add tests. I copied in a subset of the http2 package's serverTest type, adapted slightly to work with grpc. In the process of adding tests, I discovered that ErrUnexpectedEOF was also not handled by the regular server transport. Document the rules and fix that crash as well. Unrelated stuff in this CL: * make tests listen on localhost:0 instead of :0, to avoid Mac firewall pop-up dialogs. * rename parser.s field to parser.r, to be more idiomatic that it's an io.Reader and not anything fancier. (it's not acting like type stream, even if that's the typical concrete type) * move 5 byte temp buffer into parser, rather than allocating it for each new message. (drop in the bucket improvement in garbage; more to do later) * rename http2RSTErrConvTab to http2ErrConvTab, per Qi's earlier CL. Also add the HTTP/1.1-required error mapping for completeness, not that it should ever arise with gRPC, also per Qi's earlier CL referenced in #557.
This commit is contained in:
parent
178b68e281
commit
110fd99e30
2
call.go
2
call.go
|
@ -55,7 +55,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p := &parser{s: stream}
|
||||
p := &parser{r: stream}
|
||||
for {
|
||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
|
||||
if err == io.EOF {
|
||||
|
|
|
@ -75,7 +75,7 @@ type testStreamHandler struct {
|
|||
}
|
||||
|
||||
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
||||
p := &parser{s: s}
|
||||
p := &parser{r: s}
|
||||
for {
|
||||
pf, req, err := p.recvMsg()
|
||||
if err == io.EOF {
|
||||
|
@ -125,9 +125,9 @@ func newTestServer() *server {
|
|||
func (s *server) start(t *testing.T, port int, maxStreams uint32) {
|
||||
var err error
|
||||
if port == 0 {
|
||||
s.lis, err = net.Listen("tcp", ":0")
|
||||
s.lis, err = net.Listen("tcp", "localhost:0")
|
||||
} else {
|
||||
s.lis, err = net.Listen("tcp", ":"+strconv.Itoa(port))
|
||||
s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
|
||||
}
|
||||
if err != nil {
|
||||
s.startedErr <- fmt.Errorf("failed to listen: %v", err)
|
||||
|
|
38
rpc_util.go
38
rpc_util.go
|
@ -191,30 +191,44 @@ const (
|
|||
|
||||
// parser reads complelete gRPC messages from the underlying reader.
|
||||
type parser struct {
|
||||
s io.Reader
|
||||
}
|
||||
// r is the underlying reader.
|
||||
// See the comment on recvMsg for the permissible
|
||||
// error types.
|
||||
r io.Reader
|
||||
|
||||
// recvMsg is to read a complete gRPC message from the stream. It is blocking if
|
||||
// the message has not been complete yet. It returns the message and its type,
|
||||
// EOF is returned with nil msg and 0 pf if the entire stream is done. Other
|
||||
// non-nil error is returned if something is wrong on reading.
|
||||
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
|
||||
// The header of a gRPC message. Find more detail
|
||||
// at http://www.grpc.io/docs/guides/wire.html.
|
||||
var buf [5]byte
|
||||
header [5]byte
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(p.s, buf[:]); err != nil {
|
||||
// recvMsg reads a complete gRPC message from the stream.
|
||||
//
|
||||
// It returns the message and its payload (compression/encoding)
|
||||
// format. The caller owns the returned msg memory.
|
||||
//
|
||||
// If there is an error, possible values are:
|
||||
// * io.EOF, when no messages remain
|
||||
// * io.ErrUnexpectedEOF
|
||||
// * of type transport.ConnectionError
|
||||
// * of type transport.StreamError
|
||||
// No other error values or types must be returned, which also means
|
||||
// that the underlying io.Reader must not return an incompatible
|
||||
// error.
|
||||
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
|
||||
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
pf = payloadFormat(buf[0])
|
||||
length := binary.BigEndian.Uint32(buf[1:])
|
||||
pf = payloadFormat(p.header[0])
|
||||
length := binary.BigEndian.Uint32(p.header[1:])
|
||||
|
||||
if length == 0 {
|
||||
return pf, nil, nil
|
||||
}
|
||||
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
|
||||
// of making it for each message:
|
||||
msg = make([]byte, int(length))
|
||||
if _, err := io.ReadFull(p.s, msg); err != nil {
|
||||
if _, err := io.ReadFull(p.r, msg); err != nil {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
|
|
|
@ -65,7 +65,7 @@ func TestSimpleParsing(t *testing.T) {
|
|||
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
|
||||
} {
|
||||
buf := bytes.NewReader(test.p)
|
||||
parser := &parser{buf}
|
||||
parser := &parser{r: buf}
|
||||
pt, b, err := parser.recvMsg()
|
||||
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
||||
t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
||||
|
@ -77,7 +77,7 @@ func TestMultipleParsing(t *testing.T) {
|
|||
// Set a byte stream consists of 3 messages with their headers.
|
||||
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
|
||||
b := bytes.NewReader(p)
|
||||
parser := &parser{b}
|
||||
parser := &parser{r: b}
|
||||
|
||||
wantRecvs := []struct {
|
||||
pt payloadFormat
|
||||
|
|
|
@ -446,13 +446,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
}
|
||||
}()
|
||||
}
|
||||
p := &parser{s: stream}
|
||||
p := &parser{r: stream}
|
||||
for {
|
||||
pf, req, err := p.recvMsg()
|
||||
if err == io.EOF {
|
||||
// The entire stream is done (for unary RPC only).
|
||||
return err
|
||||
}
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
err = transport.StreamError{Code: codes.Internal, Desc: "io.ErrUnexpectedEOF"}
|
||||
}
|
||||
if err != nil {
|
||||
switch err := err.(type) {
|
||||
case transport.ConnectionError:
|
||||
|
@ -558,7 +561,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||
ss := &serverStream{
|
||||
t: t,
|
||||
s: stream,
|
||||
p: &parser{s: stream},
|
||||
p: &parser{r: stream},
|
||||
codec: s.opts.codec,
|
||||
cp: s.opts.cp,
|
||||
dc: s.opts.dc,
|
||||
|
|
|
@ -109,7 +109,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||
callHdr := &transport.CallHdr{
|
||||
Host: cc.authority,
|
||||
Method: method,
|
||||
Flush: desc.ServerStreams&&desc.ClientStreams,
|
||||
Flush: desc.ServerStreams && desc.ClientStreams,
|
||||
}
|
||||
if cc.dopts.cp != nil {
|
||||
callHdr.SendCompress = cc.dopts.cp.Type()
|
||||
|
@ -141,7 +141,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||
}
|
||||
cs.t = t
|
||||
cs.s = s
|
||||
cs.p = &parser{s: s}
|
||||
cs.p = &parser{r: s}
|
||||
// Listen on ctx.Done() to detect cancellation when there is no pending
|
||||
// I/O operations on this stream.
|
||||
go func() {
|
||||
|
|
|
@ -35,6 +35,8 @@ package grpc_test
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -53,6 +55,7 @@ import (
|
|||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/http2"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
@ -62,6 +65,7 @@ import (
|
|||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
testpb "google.golang.org/grpc/test/grpc_testing"
|
||||
"google.golang.org/grpc/transport"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -290,7 +294,7 @@ func TestReconnectTimeout(t *testing.T) {
|
|||
)
|
||||
defer restore()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
|
@ -354,6 +358,15 @@ func (e env) runnable() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (e env) getDialer() func(addr string, timeout time.Duration) (net.Conn, error) {
|
||||
if e.dialer != nil {
|
||||
return e.dialer
|
||||
}
|
||||
return func(addr string, timeout time.Duration) (net.Conn, error) {
|
||||
return net.DialTimeout("tcp", addr, timeout)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
tcpClearEnv = env{name: "tcp-clear", network: "tcp"}
|
||||
tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"}
|
||||
|
@ -451,7 +464,7 @@ func (te *test) startServer() {
|
|||
)
|
||||
}
|
||||
|
||||
la := ":0"
|
||||
la := "localhost:0"
|
||||
switch e.network {
|
||||
case "unix":
|
||||
la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now())
|
||||
|
@ -530,6 +543,25 @@ func (te *test) declareLogNoise(phrases ...string) {
|
|||
te.restoreLogs = declareLogNoise(te.t, phrases...)
|
||||
}
|
||||
|
||||
func (te *test) withServerTester(fn func(st *serverTester)) {
|
||||
var c net.Conn
|
||||
var err error
|
||||
c, err = te.e.getDialer()(te.srvAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
te.t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
if te.e.security == "tls" {
|
||||
c = tls.Client(c, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
NextProtos: []string{http2.NextProtoTLS},
|
||||
})
|
||||
}
|
||||
st := newServerTesterFromConn(te.t, c)
|
||||
st.greet()
|
||||
fn(st)
|
||||
}
|
||||
|
||||
func TestTimeoutOnDeadServer(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
|
@ -1613,6 +1645,145 @@ func testCompressOK(t *testing.T, e env) {
|
|||
}
|
||||
}
|
||||
|
||||
// funcServer implements methods of TestServiceServer using funcs,
|
||||
// similar to an http.HandlerFunc.
|
||||
// Any unimplemented method will crash. Tests implement the method(s)
|
||||
// they need.
|
||||
type funcServer struct {
|
||||
testpb.TestServiceServer
|
||||
unaryCall func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error)
|
||||
streamingInputCall func(stream testpb.TestService_StreamingInputCallServer) error
|
||||
}
|
||||
|
||||
func (s *funcServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
return s.unaryCall(ctx, in)
|
||||
}
|
||||
|
||||
func (s *funcServer) StreamingInputCall(stream testpb.TestService_StreamingInputCallServer) error {
|
||||
return s.streamingInputCall(stream)
|
||||
}
|
||||
|
||||
func TestClientRequestBodyError_UnexpectedEOF(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testClientRequestBodyError_UnexpectedEOF(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testClientRequestBodyError_UnexpectedEOF(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
errUnexpectedCall := errors.New("unexpected call func server method")
|
||||
t.Error(errUnexpectedCall)
|
||||
return nil, errUnexpectedCall
|
||||
}}
|
||||
te.startServer()
|
||||
defer te.tearDown()
|
||||
te.withServerTester(func(st *serverTester) {
|
||||
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall")
|
||||
// Say we have 5 bytes coming, but set END_STREAM flag:
|
||||
st.writeData(1, true, []byte{0, 0, 0, 0, 5})
|
||||
st.wantAnyFrame() // wait for server to crash (it used to crash)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientRequestBodyError_CloseAfterLength(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testClientRequestBodyError_CloseAfterLength(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testClientRequestBodyError_CloseAfterLength(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.declareLogNoise("Server.processUnaryRPC failed to write status")
|
||||
te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
errUnexpectedCall := errors.New("unexpected call func server method")
|
||||
t.Error(errUnexpectedCall)
|
||||
return nil, errUnexpectedCall
|
||||
}}
|
||||
te.startServer()
|
||||
defer te.tearDown()
|
||||
te.withServerTester(func(st *serverTester) {
|
||||
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall")
|
||||
// say we're sending 5 bytes, but then close the connection instead.
|
||||
st.writeData(1, false, []byte{0, 0, 0, 0, 5})
|
||||
st.cc.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientRequestBodyError_Cancel(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testClientRequestBodyError_Cancel(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testClientRequestBodyError_Cancel(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
gotCall := make(chan bool, 1)
|
||||
te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
gotCall <- true
|
||||
return new(testpb.SimpleResponse), nil
|
||||
}}
|
||||
te.startServer()
|
||||
defer te.tearDown()
|
||||
te.withServerTester(func(st *serverTester) {
|
||||
st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall")
|
||||
// Say we have 5 bytes coming, but cancel it instead.
|
||||
st.writeData(1, false, []byte{0, 0, 0, 0, 5})
|
||||
st.writeRSTStream(1, http2.ErrCodeCancel)
|
||||
|
||||
// Verify we didn't a call yet.
|
||||
select {
|
||||
case <-gotCall:
|
||||
t.Fatal("unexpected call")
|
||||
default:
|
||||
}
|
||||
|
||||
// And now send an uncanceled (but still invalid), just to get a response.
|
||||
st.writeHeadersGRPC(3, "/grpc.testing.TestService/UnaryCall")
|
||||
st.writeData(3, true, []byte{0, 0, 0, 0, 0})
|
||||
<-gotCall
|
||||
st.wantAnyFrame()
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientRequestBodyError_Cancel_StreamingInput(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testClientRequestBodyError_Cancel_StreamingInput(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
recvErr := make(chan error, 1)
|
||||
te.testServer = &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error {
|
||||
_, err := stream.Recv()
|
||||
recvErr <- err
|
||||
return nil
|
||||
}}
|
||||
te.startServer()
|
||||
defer te.tearDown()
|
||||
te.withServerTester(func(st *serverTester) {
|
||||
st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall")
|
||||
// Say we have 5 bytes coming, but cancel it instead.
|
||||
st.writeData(1, false, []byte{0, 0, 0, 0, 5})
|
||||
st.writeRSTStream(1, http2.ErrCodeCancel)
|
||||
|
||||
var got error
|
||||
select {
|
||||
case got = <-recvErr:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for error")
|
||||
}
|
||||
if se, ok := got.(transport.StreamError); !ok || se.Code != codes.Canceled {
|
||||
t.Errorf("error = %#v; want transport.StreamError with code Canceled")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// interestingGoroutines returns all goroutines we care about for the purpose
|
||||
// of leak checking. It excludes testing or runtime ones.
|
||||
func interestingGoroutines() (gs []string) {
|
||||
|
|
|
@ -0,0 +1,289 @@
|
|||
/*
|
||||
* Copyright 2016, Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
* * Neither the name of Google Inc. nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
package grpc_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
// This is a subset of http2's serverTester type.
|
||||
//
|
||||
// serverTester wraps a io.ReadWriter (acting like the underlying
|
||||
// network connection) and provides utility methods to read and write
|
||||
// http2 frames.
|
||||
//
|
||||
// NOTE(bradfitz): this could eventually be exported somewhere. Others
|
||||
// have asked for it too. For now I'm still experimenting with the
|
||||
// API and don't feel like maintaining a stable testing API.
|
||||
|
||||
type serverTester struct {
|
||||
cc io.ReadWriteCloser // client conn
|
||||
t testing.TB
|
||||
fr *http2.Framer
|
||||
|
||||
// writing headers:
|
||||
headerBuf bytes.Buffer
|
||||
hpackEnc *hpack.Encoder
|
||||
|
||||
// reading frames:
|
||||
frc chan http2.Frame
|
||||
frErrc chan error
|
||||
readTimer *time.Timer
|
||||
}
|
||||
|
||||
func newServerTesterFromConn(t testing.TB, cc io.ReadWriteCloser) *serverTester {
|
||||
st := &serverTester{
|
||||
t: t,
|
||||
cc: cc,
|
||||
frc: make(chan http2.Frame, 1),
|
||||
frErrc: make(chan error, 1),
|
||||
}
|
||||
st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
|
||||
st.fr = http2.NewFramer(cc, cc)
|
||||
st.fr.ReadMetaHeaders = hpack.NewDecoder(4096 /*initialHeaderTableSize*/, nil)
|
||||
|
||||
return st
|
||||
}
|
||||
|
||||
func (st *serverTester) readFrame() (http2.Frame, error) {
|
||||
go func() {
|
||||
fr, err := st.fr.ReadFrame()
|
||||
if err != nil {
|
||||
st.frErrc <- err
|
||||
} else {
|
||||
st.frc <- fr
|
||||
}
|
||||
}()
|
||||
t := time.NewTimer(2 * time.Second)
|
||||
defer t.Stop()
|
||||
select {
|
||||
case f := <-st.frc:
|
||||
return f, nil
|
||||
case err := <-st.frErrc:
|
||||
return nil, err
|
||||
case <-t.C:
|
||||
return nil, errors.New("timeout waiting for frame")
|
||||
}
|
||||
}
|
||||
|
||||
// greet initiates the client's HTTP/2 connection into a state where
|
||||
// frames may be sent.
|
||||
func (st *serverTester) greet() {
|
||||
st.writePreface()
|
||||
st.writeInitialSettings()
|
||||
st.wantSettings()
|
||||
st.writeSettingsAck()
|
||||
for {
|
||||
f, err := st.readFrame()
|
||||
if err != nil {
|
||||
st.t.Fatal(err)
|
||||
}
|
||||
switch f := f.(type) {
|
||||
case *http2.WindowUpdateFrame:
|
||||
// grpc's transport/http2_server sends this
|
||||
// before the settings ack. The Go http2
|
||||
// server uses a setting instead.
|
||||
case *http2.SettingsFrame:
|
||||
if f.IsAck() {
|
||||
return
|
||||
}
|
||||
st.t.Fatalf("during greet, got non-ACK settings frame")
|
||||
default:
|
||||
st.t.Fatalf("during greet, unexpected frame type %T", f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (st *serverTester) writePreface() {
|
||||
n, err := st.cc.Write([]byte(http2.ClientPreface))
|
||||
if err != nil {
|
||||
st.t.Fatalf("Error writing client preface: %v", err)
|
||||
}
|
||||
if n != len(http2.ClientPreface) {
|
||||
st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(http2.ClientPreface))
|
||||
}
|
||||
}
|
||||
|
||||
func (st *serverTester) writeInitialSettings() {
|
||||
if err := st.fr.WriteSettings(); err != nil {
|
||||
st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (st *serverTester) writeSettingsAck() {
|
||||
if err := st.fr.WriteSettingsAck(); err != nil {
|
||||
st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (st *serverTester) wantSettings() *http2.SettingsFrame {
|
||||
f, err := st.readFrame()
|
||||
if err != nil {
|
||||
st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
|
||||
}
|
||||
sf, ok := f.(*http2.SettingsFrame)
|
||||
if !ok {
|
||||
st.t.Fatalf("got a %T; want *SettingsFrame", f)
|
||||
}
|
||||
return sf
|
||||
}
|
||||
|
||||
func (st *serverTester) wantSettingsAck() {
|
||||
f, err := st.readFrame()
|
||||
if err != nil {
|
||||
st.t.Fatal(err)
|
||||
}
|
||||
sf, ok := f.(*http2.SettingsFrame)
|
||||
if !ok {
|
||||
st.t.Fatalf("Wanting a settings ACK, received a %T", f)
|
||||
}
|
||||
if !sf.IsAck() {
|
||||
st.t.Fatal("Settings Frame didn't have ACK set")
|
||||
}
|
||||
}
|
||||
|
||||
// wait for any activity from the server
|
||||
func (st *serverTester) wantAnyFrame() http2.Frame {
|
||||
f, err := st.fr.ReadFrame()
|
||||
if err != nil {
|
||||
st.t.Fatal(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (st *serverTester) encodeHeaderField(k, v string) {
|
||||
err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
if err != nil {
|
||||
st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
|
||||
}
|
||||
}
|
||||
|
||||
// encodeHeader encodes headers and returns their HPACK bytes. headers
|
||||
// must contain an even number of key/value pairs. There may be
|
||||
// multiple pairs for keys (e.g. "cookie"). The :method, :path, and
|
||||
// :scheme headers default to GET, / and https.
|
||||
func (st *serverTester) encodeHeader(headers ...string) []byte {
|
||||
if len(headers)%2 == 1 {
|
||||
panic("odd number of kv args")
|
||||
}
|
||||
|
||||
st.headerBuf.Reset()
|
||||
|
||||
if len(headers) == 0 {
|
||||
// Fast path, mostly for benchmarks, so test code doesn't pollute
|
||||
// profiles when we're looking to improve server allocations.
|
||||
st.encodeHeaderField(":method", "GET")
|
||||
st.encodeHeaderField(":path", "/")
|
||||
st.encodeHeaderField(":scheme", "https")
|
||||
return st.headerBuf.Bytes()
|
||||
}
|
||||
|
||||
if len(headers) == 2 && headers[0] == ":method" {
|
||||
// Another fast path for benchmarks.
|
||||
st.encodeHeaderField(":method", headers[1])
|
||||
st.encodeHeaderField(":path", "/")
|
||||
st.encodeHeaderField(":scheme", "https")
|
||||
return st.headerBuf.Bytes()
|
||||
}
|
||||
|
||||
pseudoCount := map[string]int{}
|
||||
keys := []string{":method", ":path", ":scheme"}
|
||||
vals := map[string][]string{
|
||||
":method": {"GET"},
|
||||
":path": {"/"},
|
||||
":scheme": {"https"},
|
||||
}
|
||||
for len(headers) > 0 {
|
||||
k, v := headers[0], headers[1]
|
||||
headers = headers[2:]
|
||||
if _, ok := vals[k]; !ok {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if strings.HasPrefix(k, ":") {
|
||||
pseudoCount[k]++
|
||||
if pseudoCount[k] == 1 {
|
||||
vals[k] = []string{v}
|
||||
} else {
|
||||
// Allows testing of invalid headers w/ dup pseudo fields.
|
||||
vals[k] = append(vals[k], v)
|
||||
}
|
||||
} else {
|
||||
vals[k] = append(vals[k], v)
|
||||
}
|
||||
}
|
||||
for _, k := range keys {
|
||||
for _, v := range vals[k] {
|
||||
st.encodeHeaderField(k, v)
|
||||
}
|
||||
}
|
||||
return st.headerBuf.Bytes()
|
||||
}
|
||||
|
||||
func (st *serverTester) writeHeadersGRPC(streamID uint32, path string) {
|
||||
st.writeHeaders(http2.HeadersFrameParam{
|
||||
StreamID: streamID,
|
||||
BlockFragment: st.encodeHeader(
|
||||
":method", "POST",
|
||||
":path", path,
|
||||
"content-type", "application/grpc",
|
||||
"te", "trailers",
|
||||
),
|
||||
EndStream: false,
|
||||
EndHeaders: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (st *serverTester) writeHeaders(p http2.HeadersFrameParam) {
|
||||
if err := st.fr.WriteHeaders(p); err != nil {
|
||||
st.t.Fatalf("Error writing HEADERS: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
|
||||
if err := st.fr.WriteData(streamID, endStream, data); err != nil {
|
||||
st.t.Fatalf("Error writing DATA: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) {
|
||||
if err := st.fr.WriteRSTStream(streamID, code); err != nil {
|
||||
st.t.Fatalf("Error writing RST_STREAM: %v", err)
|
||||
}
|
||||
}
|
|
@ -40,6 +40,7 @@ package transport
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
@ -319,7 +320,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
|
|||
s.buf.put(&recvMsg{data: buf[:n]})
|
||||
}
|
||||
if err != nil {
|
||||
s.buf.put(&recvMsg{err: err})
|
||||
s.buf.put(&recvMsg{err: mapRecvMsgError(err)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -352,3 +353,25 @@ func (ht *serverHandlerTransport) runStream() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mapRecvMsgError returns the non-nil err into the appropriate
|
||||
// error value as expected by callers of *grpc.parser.recvMsg.
|
||||
// In particular, in can only be:
|
||||
// * io.EOF
|
||||
// * io.ErrUnexpectedEOF
|
||||
// * of type transport.ConnectionError
|
||||
// * of type transport.StreamError
|
||||
func mapRecvMsgError(err error) error {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
return err
|
||||
}
|
||||
if se, ok := err.(http2.StreamError); ok {
|
||||
if code, ok := http2ErrConvTab[se.Code]; ok {
|
||||
return StreamError{
|
||||
Code: code,
|
||||
Desc: se.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
return ConnectionError{Desc: err.Error()}
|
||||
}
|
||||
|
|
|
@ -637,7 +637,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
|
|||
close(s.headerChan)
|
||||
s.headerDone = true
|
||||
}
|
||||
s.statusCode, ok = http2RSTErrConvTab[http2.ErrCode(f.ErrCode)]
|
||||
s.statusCode, ok = http2ErrConvTab[http2.ErrCode(f.ErrCode)]
|
||||
if !ok {
|
||||
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
|
||||
}
|
||||
|
|
|
@ -62,8 +62,8 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
clientPreface = []byte(http2.ClientPreface)
|
||||
http2RSTErrConvTab = map[http2.ErrCode]codes.Code{
|
||||
clientPreface = []byte(http2.ClientPreface)
|
||||
http2ErrConvTab = map[http2.ErrCode]codes.Code{
|
||||
http2.ErrCodeNo: codes.Internal,
|
||||
http2.ErrCodeProtocol: codes.Internal,
|
||||
http2.ErrCodeInternal: codes.Internal,
|
||||
|
@ -76,6 +76,7 @@ var (
|
|||
http2.ErrCodeConnect: codes.Internal,
|
||||
http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted,
|
||||
http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
|
||||
http2.ErrCodeHTTP11Required: codes.FailedPrecondition,
|
||||
}
|
||||
statusCodeConvTab = map[codes.Code]http2.ErrCode{
|
||||
codes.Internal: http2.ErrCodeInternal,
|
||||
|
|
|
@ -131,9 +131,9 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
|
|||
func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
||||
var err error
|
||||
if port == 0 {
|
||||
s.lis, err = net.Listen("tcp", ":0")
|
||||
s.lis, err = net.Listen("tcp", "localhost:0")
|
||||
} else {
|
||||
s.lis, err = net.Listen("tcp", ":"+strconv.Itoa(port))
|
||||
s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
|
||||
}
|
||||
if err != nil {
|
||||
s.startedErr <- fmt.Errorf("failed to listen: %v", err)
|
||||
|
@ -568,7 +568,7 @@ func TestServerWithMisbehavedClient(t *testing.T) {
|
|||
sent++
|
||||
}
|
||||
// Server sent a resetStream for s already.
|
||||
code := http2RSTErrConvTab[http2.ErrCodeFlowControl]
|
||||
code := http2ErrConvTab[http2.ErrCodeFlowControl]
|
||||
if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF || s.statusCode != code {
|
||||
t.Fatalf("%v got err %v with statusCode %d, want err <EOF> with statusCode %d", s, err, s.statusCode, code)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue