Expand stream's flow control in case of an active read. (#1248)

* First commit

* Imported tests from the original PR by @apolcyn.

* Formatting fixes.

* More formating fixes

* more golint

* Make logs more informative.

* post-review update

* Added test to check flow control accounts after sending large messages.

* post-review update

* Empty commit to kickstart travis.

* Post-review update.
This commit is contained in:
MakMukhi 2017-05-23 11:39:15 -07:00 committed by GitHub
parent 79f73d62e5
commit 6dff7c5f33
10 changed files with 591 additions and 62 deletions

View File

@ -278,7 +278,7 @@ type parser struct {
// that the underlying io.Reader must not return an incompatible
// error.
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
if _, err := p.r.Read(p.header[:]); err != nil {
return 0, nil, err
}
@ -294,7 +294,7 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
if _, err := io.ReadFull(p.r, msg); err != nil {
if _, err := p.r.Read(msg); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}

View File

@ -47,6 +47,14 @@ import (
"google.golang.org/grpc/transport"
)
type fullReader struct {
reader io.Reader
}
func (f fullReader) Read(p []byte) (int, error) {
return io.ReadFull(f.reader, p)
}
var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
func TestSimpleParsing(t *testing.T) {
@ -67,7 +75,7 @@ func TestSimpleParsing(t *testing.T) {
// Check that messages with length >= 2^24 are parsed.
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
} {
buf := bytes.NewReader(test.p)
buf := fullReader{bytes.NewReader(test.p)}
parser := &parser{r: buf}
pt, b, err := parser.recvMsg(math.MaxInt32)
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
@ -79,7 +87,7 @@ func TestSimpleParsing(t *testing.T) {
func TestMultipleParsing(t *testing.T) {
// Set a byte stream consists of 3 messages with their headers.
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
b := bytes.NewReader(p)
b := fullReader{bytes.NewReader(p)}
parser := &parser{r: b}
wantRecvs := []struct {

View File

@ -449,6 +449,7 @@ type test struct {
streamServerInt grpc.StreamServerInterceptor
unknownHandler grpc.StreamHandler
sc <-chan grpc.ServiceConfig
customCodec grpc.Codec
serverInitialWindowSize int32
serverInitialConnWindowSize int32
clientInitialWindowSize int32
@ -555,6 +556,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
case "clientTimeoutCreds":
sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{}))
}
if te.customCodec != nil {
sopts = append(sopts, grpc.CustomCodec(te.customCodec))
}
s := grpc.NewServer(sopts...)
te.srv = s
if te.e.httpHandler {
@ -641,6 +645,9 @@ func (te *test) clientConn() *grpc.ClientConn {
if te.perRPCCreds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
}
if te.customCodec != nil {
opts = append(opts, grpc.WithCodec(te.customCodec))
}
var err error
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err != nil {
@ -3271,26 +3278,51 @@ func testServerStreamingConcurrent(t *testing.T, e env) {
}
func generatePayloadSizes() [][]int {
reqSizes := [][]int{
{27182, 8, 1828, 45904},
}
num8KPayloads := 1024
eightKPayloads := []int{}
for i := 0; i < num8KPayloads; i++ {
eightKPayloads = append(eightKPayloads, (1 << 13))
}
reqSizes = append(reqSizes, eightKPayloads)
num2MPayloads := 8
twoMPayloads := []int{}
for i := 0; i < num2MPayloads; i++ {
twoMPayloads = append(twoMPayloads, (1 << 21))
}
reqSizes = append(reqSizes, twoMPayloads)
return reqSizes
}
func TestClientStreaming(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testClientStreaming(t, e)
for _, s := range generatePayloadSizes() {
for _, e := range listTestEnv() {
testClientStreaming(t, e, s)
}
}
}
func testClientStreaming(t *testing.T, e env) {
func testClientStreaming(t *testing.T, e env, sizes []int) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
stream, err := tc.StreamingInputCall(te.ctx)
ctx, _ := context.WithTimeout(te.ctx, time.Second*30)
stream, err := tc.StreamingInputCall(ctx)
if err != nil {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want <nil>", tc, err)
}
var sum int
for _, s := range reqSizes {
for _, s := range sizes {
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(s))
if err != nil {
t.Fatal(err)

View File

@ -287,3 +287,9 @@ func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) {
st.t.Fatalf("Error writing RST_STREAM: %v", err)
}
}
func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, padding []byte) {
if err := st.fr.WriteDataPadded(streamID, endStream, data, padding); err != nil {
st.t.Fatalf("Error writing DATA with padding: %v", err)
}
}

View File

@ -58,6 +58,8 @@ const (
defaultServerKeepaliveTime = time.Duration(2 * time.Hour)
defaultServerKeepaliveTimeout = time.Duration(20 * time.Second)
defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute)
// max window limit set by HTTP2 Specs.
maxWindowSize = math.MaxInt32
)
// The following defines various control items which could flow through
@ -167,6 +169,40 @@ type inFlow struct {
// The amount of data the application has consumed but grpc has not sent
// window update for them. Used to reduce window update frequency.
pendingUpdate uint32
// delta is the extra window update given by receiver when an application
// is reading data bigger in size than the inFlow limit.
delta uint32
}
func (f *inFlow) maybeAdjust(n uint32) uint32 {
if n > uint32(math.MaxInt32) {
n = uint32(math.MaxInt32)
}
f.mu.Lock()
defer f.mu.Unlock()
// estSenderQuota is the receiver's view of the maximum number of bytes the sender
// can send without a window update.
estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
// estUntransmittedData is the maximum number of bytes the sends might not have put
// on the wire yet. A value of 0 or less means that we have already received all or
// more bytes than the application is requesting to read.
estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative.
// This implies that unless we send a window update, the sender won't be able to send all the bytes
// for this message. Therefore we must send an update over the limit since there's an active read
// request from the application.
if estUntransmittedData > estSenderQuota {
// Sender's window shouldn't go more than 2^31 - 1 as speecified in the HTTP spec.
if f.limit+n > maxWindowSize {
f.delta = maxWindowSize - f.limit
} else {
// Send a window update for the whole message and not just the difference between
// estUntransmittedData and estSenderQuota. This will be helpful in case the message
// is padded; We will fallback on the current available window(at least a 1/4th of the limit).
f.delta = n
}
return f.delta
}
return 0
}
// onData is invoked when some data frame is received. It updates pendingData.
@ -174,7 +210,7 @@ func (f *inFlow) onData(n uint32) error {
f.mu.Lock()
defer f.mu.Unlock()
f.pendingData += n
if f.pendingData+f.pendingUpdate > f.limit {
if f.pendingData+f.pendingUpdate > f.limit+f.delta {
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit)
}
return nil
@ -189,6 +225,13 @@ func (f *inFlow) onRead(n uint32) uint32 {
return 0
}
f.pendingData -= n
if n > f.delta {
n -= f.delta
f.delta = 0
} else {
f.delta -= n
n = 0
}
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
wu := f.pendingUpdate

View File

@ -316,13 +316,12 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
req := ht.req
s := &Stream{
id: 0, // irrelevant
windowHandler: func(int) {}, // nothing
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
id: 0, // irrelevant
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
@ -333,7 +332,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
ctx = peer.NewContext(ctx, pr)
s.ctx = newContextWithStream(ctx, s)
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
s.trReader = &recvBufferReader{ctx: s.ctx, recv: s.buf}
// readerDone is closed when the Body.Read-ing goroutine exits.
readerDone := make(chan struct{})

View File

@ -173,9 +173,9 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
conn, err := dial(ctx, opts.Dialer, addr.Addr)
if err != nil {
if opts.FailOnNonTempDialError {
return nil, connectionErrorf(isTemporary(err), err, "transport: %v", err)
return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err)
}
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(true, err, "transport: Error while dialing %v", err)
}
// Any further errors will close the underlying connection
defer func(conn net.Conn) {
@ -194,7 +194,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
// Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates.
temp := isTemporary(err)
return nil, connectionErrorf(temp, err, "transport: %v", err)
return nil, connectionErrorf(temp, err, "transport: authentication handshake failed: %v", err)
}
isSecure = true
}
@ -269,7 +269,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
n, err := t.conn.Write(clientPreface)
if err != nil {
t.Close()
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(true, err, "transport: failed to write client preface: %v", err)
}
if n != len(clientPreface) {
t.Close()
@ -285,13 +285,13 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
}
if err != nil {
t.Close()
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(icwz - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close()
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err)
}
}
go t.controller()
@ -316,18 +316,24 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
headerChan: make(chan struct{}),
}
t.nextID += 2
s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n))
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
// The client side stream context should have exactly the same life cycle with the user provided context.
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
// So we use the original context here instead of creating a copy.
s.ctx = ctx
s.dec = &recvBufferReader{
ctx: s.ctx,
goAway: s.goAway,
recv: s.buf,
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
goAway: s.goAway,
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
},
}
return s
}
@ -802,6 +808,20 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
return s, ok
}
// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Client) adjustWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&windowUpdate{s.id, w})
}
}
// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.

View File

@ -274,10 +274,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
}
s.dec = &recvBufferReader{
ctx: s.ctx,
recv: s.buf,
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
},
}
s.recvCompress = state.encoding
s.method = state.method
@ -316,8 +320,8 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.idle = time.Time{}
}
t.mu.Unlock()
s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n))
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
s.ctx = traceCtx(s.ctx, s.method)
if t.stats != nil {
@ -361,7 +365,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
return
}
if err != nil {
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
grpclog.Printf("transport: http2Server.HandleStreams failed to read initial settings frame: %v", err)
t.Close()
return
}
@ -435,6 +439,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
return s, true
}
// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Server) adjustWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&windowUpdate{s.id, w})
}
}
// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.

View File

@ -185,14 +185,17 @@ type Stream struct {
recvCompress string
sendCompress string
buf *recvBuffer
dec io.Reader
trReader io.Reader
fc *inFlow
recvQuota uint32
// TODO: Remote this unused variable.
// The accumulated inbound quota pending for window update.
updateQuota uint32
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
windowHandler func(int)
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if need be.
requestRead func(int)
sendQuotaPool *quotaPool
// Close headerChan to indicate the end of reception of header metadata.
@ -320,16 +323,35 @@ func (s *Stream) write(m recvMsg) {
s.buf.put(&m)
}
// Read reads all the data available for this Stream from the transport and
// Read reads all p bytes from the wire for this stream.
func (s *Stream) Read(p []byte) (n int, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.(*transportReader).er; er != nil {
return 0, er
}
s.requestRead(len(p))
return io.ReadFull(s.trReader, p)
}
// tranportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if
// the stream broke.
func (s *Stream) Read(p []byte) (n int, err error) {
n, err = s.dec.Read(p)
type transportReader struct {
reader io.Reader
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
windowHandler func(int)
er error
}
func (t *transportReader) Read(p []byte) (n int, err error) {
n, err = t.reader.Read(p)
if err != nil {
t.er = err
return
}
s.windowHandler(n)
t.windowHandler(n)
return
}

View File

@ -36,6 +36,8 @@ package transport
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
@ -84,6 +86,9 @@ const (
misbehaved
encodingRequiredStatus
invalidHeaderField
delayRead
delayWrite
pingpong
)
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
@ -94,7 +99,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
resp = expectedResponseLarge
}
p := make([]byte, len(req))
_, err := io.ReadFull(s, p)
_, err := s.Read(p)
if err != nil {
return
}
@ -107,6 +112,25 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
header := make([]byte, 5)
for i := 0; i < 10; i++ {
if _, err := s.Read(header); err != nil {
t.Fatalf("Error on server while reading data header: %v", err)
}
sz := binary.BigEndian.Uint32(header[1:])
msg := make([]byte, int(sz))
if _, err := s.Read(msg); err != nil {
t.Fatalf("Error on server while reading message: %v", err)
}
buf := make([]byte, sz+5)
buf[0] = byte(0)
binary.BigEndian.PutUint32(buf[1:], uint32(sz))
copy(buf[5:], msg)
h.t.Write(s, buf, &Options{})
}
}
// handleStreamSuspension blocks until s.ctx is canceled.
func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
go func() {
@ -159,6 +183,58 @@ func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stre
h.t.writableChan <- 0
}
func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
req := expectedRequest
resp := expectedResponse
if s.Method() == "foo.Large" {
req = expectedRequestLarge
resp = expectedResponseLarge
}
p := make([]byte, len(req))
// Wait before reading. Give time to client to start sending
// before server starts reading.
time.Sleep(2 * time.Second)
_, err := s.Read(p)
if err != nil {
t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err)
return
}
if !bytes.Equal(p, req) {
t.Fatalf("handleStream got %v, want %v", p, req)
}
// send a response back to the client.
h.t.Write(s, resp, &Options{})
// send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
req := expectedRequest
resp := expectedResponse
if s.Method() == "foo.Large" {
req = expectedRequestLarge
resp = expectedResponseLarge
}
p := make([]byte, len(req))
_, err := s.Read(p)
if err != nil {
t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err)
return
}
if !bytes.Equal(p, req) {
t.Fatalf("handleStream got %v, want %v", p, req)
}
// Wait before sending. Give time to client to start reading
// before server starts sending.
time.Sleep(2 * time.Second)
h.t.Write(s, resp, &Options{})
// send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
// start starts server. Other goroutines should block on s.readyChan for further operations.
func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) {
var err error
@ -221,6 +297,24 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case delayRead:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamDelayRead(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case delayWrite:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamDelayWrite(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case pingpong:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamPingPong(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
default:
go transport.HandleStreams(func(s *Stream) {
go h.handleStream(t, s)
@ -696,11 +790,11 @@ func TestClientSendAndReceive(t *testing.T) {
t.Fatalf("failed to send data: %v", err)
}
p := make([]byte, len(expectedResponse))
_, recvErr := io.ReadFull(s1, p)
_, recvErr := s1.Read(p)
if recvErr != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
}
_, recvErr = io.ReadFull(s1, p)
_, recvErr = s1.Read(p)
if recvErr != io.EOF {
t.Fatalf("Error: %v; want <EOF>", recvErr)
}
@ -736,9 +830,9 @@ func performOneRPC(ct ClientTransport) {
//
// Read response
p := make([]byte, len(expectedResponse))
io.ReadFull(s, p)
s.Read(p)
// Read io.EOF
io.ReadFull(s, p)
s.Read(p)
}
}
@ -777,10 +871,80 @@ func TestLargeMessage(t *testing.T) {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
if _, err = io.ReadFull(s, p); err != io.EOF {
if _, err = s.Read(p); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
}
wg.Wait()
ct.Close()
server.stop()
}
func TestLargeMessageWithDelayRead(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, delayRead)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Large",
}
var wg sync.WaitGroup
for i := 0; i < 2; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
// Give time to server to begin sending before client starts reading.
time.Sleep(2 * time.Second)
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
}
if _, err = s.Read(p); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
}
wg.Wait()
ct.Close()
server.stop()
}
func TestLargeMessageDelayWrite(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, delayWrite)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Large",
}
var wg sync.WaitGroup
for i := 0; i < 2; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
// Give time to server to start reading before client starts sending.
time.Sleep(2 * time.Second)
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
if _, err = s.Read(p); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
@ -823,10 +987,10 @@ func TestGracefulClose(t *testing.T) {
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponse))
if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
if _, err = io.ReadFull(s, p); err != io.EOF {
if _, err = s.Read(p); err != io.EOF {
t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
}
wg.Wait()
@ -1074,7 +1238,7 @@ func TestServerWithMisbehavedClient(t *testing.T) {
}
// Server sent a resetStream for s already.
code := http2ErrConvTab[http2.ErrCodeFlowControl]
if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF {
if _, err := s.Read(make([]byte, 1)); err != io.EOF {
t.Fatalf("%v got err %v want <EOF>", s, err)
}
if s.status.Code() != code {
@ -1125,7 +1289,7 @@ func TestClientWithMisbehavedServer(t *testing.T) {
// Read without window update.
for {
p := make([]byte, http2MaxFrameLen)
if _, err = s.dec.Read(p); err != nil {
if _, err = s.trReader.(*transportReader).reader.Read(p); err != nil {
break
}
}
@ -1184,7 +1348,7 @@ func TestEncodingRequiredStatus(t *testing.T) {
t.Fatalf("Failed to write the request: %v", err)
}
p := make([]byte, http2MaxFrameLen)
if _, err := s.dec.Read(p); err != io.EOF {
if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
t.Fatalf("Read got error %v, want %v", err, io.EOF)
}
if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
@ -1212,7 +1376,7 @@ func TestInvalidHeaderField(t *testing.T) {
t.Fatalf("Failed to write the request: %v", err)
}
p := make([]byte, http2MaxFrameLen)
_, err = s.dec.Read(p)
_, err = s.trReader.(*transportReader).Read(p)
if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField)
}
@ -1269,6 +1433,13 @@ func TestContextErr(t *testing.T) {
}
}
func max(a, b int32) int32 {
if a > b {
return a
}
return b
}
type windowSizeConfig struct {
serverStream int32
serverConn int32
@ -1348,6 +1519,7 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) {
}
return false, nil
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
serverSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire())
if err != nil {
@ -1395,6 +1567,166 @@ func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) {
}
}
// Check accounting on both sides after sending and receiving large messages.
func TestAccountCheckExpandingWindow(t *testing.T) {
server, client := setUp(t, 0, 0, pingpong)
defer server.stop()
defer client.Close()
waitWhileTrue(t, func() (bool, error) {
server.mu.Lock()
defer server.mu.Unlock()
if len(server.conns) == 0 {
return true, fmt.Errorf("timed out while waiting for server transport to be created")
}
return false, nil
})
var st *http2Server
server.mu.Lock()
for k := range server.conns {
st = k.(*http2Server)
}
server.mu.Unlock()
ct := client.(*http2Client)
cstream, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
if err != nil {
t.Fatalf("Failed to create stream. Err: %v", err)
}
msgSize := 65535 * 16 * 2
msg := make([]byte, msgSize)
buf := make([]byte, msgSize+5)
buf[0] = byte(0)
binary.BigEndian.PutUint32(buf[1:], uint32(msgSize))
copy(buf[5:], msg)
opts := Options{}
header := make([]byte, 5)
for i := 1; i <= 10; i++ {
if err := ct.Write(cstream, buf, &opts); err != nil {
t.Fatalf("Error on client while writing message: %v", err)
}
if _, err := cstream.Read(header); err != nil {
t.Fatalf("Error on client while reading data frame header: %v", err)
}
sz := binary.BigEndian.Uint32(header[1:])
recvMsg := make([]byte, int(sz))
if _, err := cstream.Read(recvMsg); err != nil {
t.Fatalf("Error on client while reading data: %v", err)
}
if len(recvMsg) != len(msg) {
t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg))
}
}
var sstream *Stream
st.mu.Lock()
for _, v := range st.activeStreams {
sstream = v
}
st.mu.Unlock()
waitWhileTrue(t, func() (bool, error) {
// Check that pendingData and delta on flow control windows on both sides are 0.
cstream.fc.mu.Lock()
if cstream.fc.delta != 0 {
cstream.fc.mu.Unlock()
return true, fmt.Errorf("delta on flow control window of client stream is non-zero")
}
if cstream.fc.pendingData != 0 {
cstream.fc.mu.Unlock()
return true, fmt.Errorf("pendingData on flow control window of client stream is non-zero")
}
cstream.fc.mu.Unlock()
sstream.fc.mu.Lock()
if sstream.fc.delta != 0 {
sstream.fc.mu.Unlock()
return true, fmt.Errorf("delta on flow control window of server stream is non-zero")
}
if sstream.fc.pendingData != 0 {
sstream.fc.mu.Unlock()
return true, fmt.Errorf("pendingData on flow control window of sercer stream is non-zero")
}
sstream.fc.mu.Unlock()
ct.fc.mu.Lock()
if ct.fc.delta != 0 {
ct.fc.mu.Unlock()
return true, fmt.Errorf("delta on flow control window of client transport is non-zero")
}
if ct.fc.pendingData != 0 {
ct.fc.mu.Unlock()
return true, fmt.Errorf("pendingData on flow control window of client transport is non-zero")
}
ct.fc.mu.Unlock()
st.fc.mu.Lock()
if st.fc.delta != 0 {
st.fc.mu.Unlock()
return true, fmt.Errorf("delta on flow control window of server transport is non-zero")
}
if st.fc.pendingData != 0 {
st.fc.mu.Unlock()
return true, fmt.Errorf("pendingData on flow control window of server transport is non-zero")
}
st.fc.mu.Unlock()
// Check flow conrtrol window on client stream is equal to out flow on server stream.
ctx, _ := context.WithTimeout(context.Background(), time.Second)
serverStreamSendQuota, err := wait(ctx, nil, nil, nil, sstream.sendQuotaPool.acquire())
if err != nil {
return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err)
}
sstream.sendQuotaPool.add(serverStreamSendQuota)
cstream.fc.mu.Lock()
if uint32(serverStreamSendQuota) != cstream.fc.limit-cstream.fc.pendingUpdate {
cstream.fc.mu.Unlock()
return true, fmt.Errorf("server stream outflow: %v, estimated by client: %v", serverStreamSendQuota, cstream.fc.limit-cstream.fc.pendingUpdate)
}
cstream.fc.mu.Unlock()
// Check flow control window on server stream is equal to out flow on client stream.
ctx, _ = context.WithTimeout(context.Background(), time.Second)
clientStreamSendQuota, err := wait(ctx, nil, nil, nil, cstream.sendQuotaPool.acquire())
if err != nil {
return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err)
}
cstream.sendQuotaPool.add(clientStreamSendQuota)
sstream.fc.mu.Lock()
if uint32(clientStreamSendQuota) != sstream.fc.limit-sstream.fc.pendingUpdate {
sstream.fc.mu.Unlock()
return true, fmt.Errorf("client stream outflow: %v. estimated by server: %v", clientStreamSendQuota, sstream.fc.limit-sstream.fc.pendingUpdate)
}
sstream.fc.mu.Unlock()
// Check flow control window on client transport is equal to out flow of server transport.
ctx, _ = context.WithTimeout(context.Background(), time.Second)
serverTrSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire())
if err != nil {
return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err)
}
st.sendQuotaPool.add(serverTrSendQuota)
ct.fc.mu.Lock()
if uint32(serverTrSendQuota) != ct.fc.limit-ct.fc.pendingUpdate {
ct.fc.mu.Unlock()
return true, fmt.Errorf("server transport outflow: %v, estimated by client: %v", serverTrSendQuota, ct.fc.limit-ct.fc.pendingUpdate)
}
ct.fc.mu.Unlock()
// Check flow control window on server transport is equal to out flow of client transport.
ctx, _ = context.WithTimeout(context.Background(), time.Second)
clientTrSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire())
if err != nil {
return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err)
}
ct.sendQuotaPool.add(clientTrSendQuota)
st.fc.mu.Lock()
if uint32(clientTrSendQuota) != st.fc.limit-st.fc.pendingUpdate {
st.fc.mu.Unlock()
return true, fmt.Errorf("client transport outflow: %v, estimated by client: %v", clientTrSendQuota, st.fc.limit-st.fc.pendingUpdate)
}
st.fc.mu.Unlock()
return false, nil
})
}
func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
var (
wait bool
@ -1576,7 +1908,8 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders)
stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
defer cleanUp()
want := httpStatusConvTab[httpStatus]
_, err := stream.Read([]byte{})
buf := make([]byte, 8)
_, err := stream.Read(buf)
if err == nil {
t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
}
@ -1592,7 +1925,8 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders)
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
defer cleanUp()
_, err := stream.Read([]byte{})
buf := make([]byte, 8)
_, err := stream.Read(buf)
if err != io.EOF {
t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
}
@ -1607,3 +1941,50 @@ func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders)
}
// If any error occurs on a call to Stream.Read, future calls
// should continue to return that same error.
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
testRecvBuffer := newRecvBuffer()
s := &Stream{
ctx: context.Background(),
goAway: make(chan struct{}),
buf: testRecvBuffer,
requestRead: func(int) {},
}
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
goAway: s.goAway,
recv: s.buf,
},
windowHandler: func(int) {},
}
testData := make([]byte, 1)
testData[0] = 5
testErr := errors.New("test error")
s.write(recvMsg{data: testData, err: testErr})
inBuf := make([]byte, 1)
actualCount, actualErr := s.Read(inBuf)
if actualCount != 0 {
t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
}
if actualErr.Error() != testErr.Error() {
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
}
s.write(recvMsg{data: testData, err: nil})
s.write(recvMsg{data: testData, err: errors.New("different error from first")})
for i := 0; i < 2; i++ {
inBuf := make([]byte, 1)
actualCount, actualErr := s.Read(inBuf)
if actualCount != 0 {
t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
}
if actualErr.Error() != testErr.Error() {
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
}
}
}