diff --git a/call_test.go b/call_test.go index 2bcea807f..48134c4c2 100644 --- a/call_test.go +++ b/call_test.go @@ -66,17 +66,16 @@ type testStreamHandler struct { } func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { - p := &parser{r: s} for { - pf, req, err := p.recvMsg(math.MaxInt32) + isCompressed, req, err := recvMsg(s, math.MaxInt32) if err == io.EOF { break } if err != nil { return } - if pf != compressionNone { - t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone) + if isCompressed { + t.Errorf("Received compressed message want non-compressed message") return } var v string @@ -105,12 +104,12 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { } } // send a response back to end the stream. - hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) + data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) if err != nil { t.Errorf("Failed to encode the response: %v", err) return } - h.t.Write(s, hdr, data, &transport.Options{}) + h.t.Write(s, data, &transport.Options{}) h.t.WriteStatus(s, status.New(codes.OK, "")) } diff --git a/internal/msgdecoder/msgdecoder.go b/internal/msgdecoder/msgdecoder.go new file mode 100644 index 000000000..bbcc4797d --- /dev/null +++ b/internal/msgdecoder/msgdecoder.go @@ -0,0 +1,203 @@ +/* + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +// Package msgdecoder contains the logic to deconstruct a gRPC-message. +package msgdecoder + +import ( + "encoding/binary" +) + +// RecvMsg is a message constructed from the incoming +// bytes on the transport for a stream. +// An instance of RecvMsg will contain only one of the +// following: message header related fields, data slice +// or error. +type RecvMsg struct { + // Following three are message header related + // fields. + // true if the message was compressed by the other + // side. + IsCompressed bool + // Length of the message. + Length int + // Overhead is the length of message header(5 bytes) + // plus padding. + Overhead int + + // Data payload of the message. + Data []byte + + // Err occurred while reading. + // nil: received some data + // io.EOF: stream is completed. data is nil. + // other non-nil error: transport failure. data is nil. + Err error + + Next *RecvMsg +} + +// RecvMsgList is a linked-list of RecvMsg. +type RecvMsgList struct { + head *RecvMsg + tail *RecvMsg +} + +// IsEmpty returns true when l is empty. +func (l *RecvMsgList) IsEmpty() bool { + if l.tail == nil { + return true + } + return false +} + +// Enqueue adds r to l at the back. +func (l *RecvMsgList) Enqueue(r *RecvMsg) { + if l.IsEmpty() { + l.head, l.tail = r, r + return + } + t := l.tail + l.tail = r + t.Next = r +} + +// Dequeue removes a RecvMsg from the front of l. +func (l *RecvMsgList) Dequeue() *RecvMsg { + if l.head == nil { + // Note to developer: Instead of calling IsEmpty() which + // checks the same condition on l.tail, we check it directly + // on l.head so that in non-nil cases, there aren't cache misses. + return nil + } + r := l.head + l.head = l.head.Next + if l.head == nil { + l.tail = nil + } + return r +} + +// MessageDecoder decodes bytes from HTTP2 data frames +// and constructs a gRPC message which is then put in a +// buffer that application(RPCs) read from.
+// gRPC Messages: +// First 5 bytes is the message header: +// First byte: Payload format. +// Next 4 bytes: Length of the message. +// Rest of the bytes is the message payload. +// +// TODO(mmukhi): Write unit tests. +type MessageDecoder struct { + // current message being read by the transport. + current *RecvMsg + dataOfst int + padding int + // hdr stores the message header as it is being received by the transport. + hdr []byte + hdrOfst int + // Callback used to send decoded messages. + dispatch func(*RecvMsg) +} + +// NewMessageDecoder creates an instance of MessageDecoder. It takes a callback +// which is called to dispatch finished headers and messages to the application. +func NewMessageDecoder(dispatch func(*RecvMsg)) *MessageDecoder { + return &MessageDecoder{ + hdr: make([]byte, 5), + dispatch: dispatch, + } +} + +// Decode consumes bytes from a HTTP2 data frame to create gRPC messages. +func (m *MessageDecoder) Decode(b []byte, padding int) { + m.padding += padding + for len(b) > 0 { + // Case 1: A complete message hdr was received earlier. + if m.current != nil { + n := copy(m.current.Data[m.dataOfst:], b) + m.dataOfst += n + b = b[n:] + if m.dataOfst == len(m.current.Data) { // Message is complete. + m.dispatch(m.current) + m.current = nil + m.dataOfst = 0 + } + continue + } + // Case 2a: No message header has been received yet. + if m.hdrOfst == 0 { + // case 2a.1: b has the whole header + if len(b) >= 5 { + m.parseHeader(b[:5]) + b = b[5:] + continue + } + // case 2a.2: b has partial header + n := copy(m.hdr, b) + m.hdrOfst = n + b = b[n:] + continue + } + // Case 2b: Partial message header was received earlier. + n := copy(m.hdr[m.hdrOfst:], b) + m.hdrOfst += n + b = b[n:] + if m.hdrOfst == 5 { // hdr is complete.
+ m.hdrOfst = 0 + m.parseHeader(m.hdr) + } + } +} + +func (m *MessageDecoder) parseHeader(b []byte) { + length := int(binary.BigEndian.Uint32(b[1:5])) + hdr := &RecvMsg{ + IsCompressed: int(b[0]) == 1, + Length: length, + Overhead: m.padding + 5, + } + m.padding = 0 + // Dispatch the information retrieved from message header so + // that the RPC goroutine can send a proactive window update as we + // wait for the rest of it. + m.dispatch(hdr) + if length == 0 { + m.dispatch(&RecvMsg{}) + return + } + m.current = &RecvMsg{ + Data: getMem(length), + } +} + +func getMem(l int) []byte { + // TODO(mmukhi): Reuse this memory. + return make([]byte, l) +} + +// CreateMessageHeader creates a gRPC-specific message header. +func CreateMessageHeader(l int, isCompressed bool) []byte { + // TODO(mmukhi): Investigate if this memory is worth + // reusing. + hdr := make([]byte, 5) + if isCompressed { + hdr[0] = byte(1) + } + binary.BigEndian.PutUint32(hdr[1:], uint32(l)) + return hdr +} diff --git a/internal/msgdecoder/msgdecoder_test.go b/internal/msgdecoder/msgdecoder_test.go new file mode 100644 index 000000000..caa498ea1 --- /dev/null +++ b/internal/msgdecoder/msgdecoder_test.go @@ -0,0 +1,81 @@ +/* + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * + */ + +package msgdecoder + +import ( + "encoding/binary" + "reflect" + "testing" +) + +func TestMessageDecoder(t *testing.T) { + for _, test := range []struct { + numFrames int + data []string + }{ + {1, []string{"abc"}}, // One message per frame. + {1, []string{"abc", "def", "ghi"}}, // Multiple messages per frame. + {3, []string{"a", "bcdef", "ghif"}}, // Multiple messages over multiple frames. + } { + var want []*RecvMsg + for _, d := range test.data { + want = append(want, &RecvMsg{Length: len(d), Overhead: 5}) + want = append(want, &RecvMsg{Data: []byte(d)}) + } + var got []*RecvMsg + dcdr := NewMessageDecoder(func(r *RecvMsg) { got = append(got, r) }) + for _, fr := range createFrames(test.numFrames, test.data) { + dcdr.Decode(fr, 0) + } + if !match(got, want) { + t.Fatalf("got: %v, want: %v", got, want) + } + } +} + +func match(got, want []*RecvMsg) bool { + for i, v := range got { + if !reflect.DeepEqual(v, want[i]) { + return false + } + } + return true +} + +func createFrames(n int, msgs []string) [][]byte { + var b []byte + for _, m := range msgs { + payload := []byte(m) + hdr := make([]byte, 5) + binary.BigEndian.PutUint32(hdr[1:], uint32(len(payload))) + b = append(b, hdr...) + b = append(b, payload...) + } + // break b into n parts. + var result [][]byte + batch := len(b) / n + for len(b) != 0 { + sz := batch + if len(b) < sz { + sz = len(b) + } + result = append(result, b[:sz]) + b = b[sz:] + } + return result +} diff --git a/rpc_util.go b/rpc_util.go index f08e646f8..90d1b849c 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -21,7 +21,6 @@ package grpc import ( "bytes" "compress/gzip" - "encoding/binary" "fmt" "io" "io/ioutil" @@ -415,85 +414,39 @@ func (o CustomCodecCallOption) before(c *callInfo) error { } func (o CustomCodecCallOption) after(c *callInfo) {} -// The format of the payload: compressed or not? 
-type payloadFormat uint8 - -const ( - compressionNone payloadFormat = iota // no compression - compressionMade -) - -// parser reads complete gRPC messages from the underlying reader. -type parser struct { - // r is the underlying reader. - // See the comment on recvMsg for the permissible - // error types. - r io.Reader - - // The header of a gRPC message. Find more detail at - // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md - header [5]byte -} - // recvMsg reads a complete gRPC message from the stream. // -// It returns the message and its payload (compression/encoding) -// format. The caller owns the returned msg memory. +// It returns a flag set to true if message was compressed, +// the message as a byte slice or error if so. +// The caller owns the returned msg memory. // // If there is an error, possible values are: // * io.EOF, when no messages remain // * io.ErrUnexpectedEOF // * of type transport.ConnectionError // * of type transport.StreamError -// No other error values or types must be returned, which also means -// that the underlying io.Reader must not return an incompatible -// error. -func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) { - if _, err := p.r.Read(p.header[:]); err != nil { - return 0, nil, err +// No other error values or types must be returned. +func recvMsg(s *transport.Stream, maxRecvMsgSize int) (bool, []byte, error) { + isCompressed, msg, err := s.Read(maxRecvMsgSize) + if err != nil { + return false, nil, err } - - pf = payloadFormat(p.header[0]) - length := binary.BigEndian.Uint32(p.header[1:]) - - if length == 0 { - return pf, nil, nil - } - if int64(length) > int64(maxInt) { - return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. 
%d)", length, maxInt) - } - if int(length) > maxReceiveMessageSize { - return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize) - } - // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead - // of making it for each message: - msg = make([]byte, int(length)) - if _, err := p.r.Read(msg); err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return 0, nil, err - } - return pf, msg, nil + return isCompressed, msg, nil } -// encode serializes msg and returns a buffer of message header and a buffer of msg. -// If msg is nil, it generates the message header and an empty msg buffer. +// encode serializes msg and returns a buffer of msg. +// If msg is nil, it generates an empty buffer. // TODO(ddyihai): eliminate extra Compressor parameter. -func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) { +func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, error) { var ( b []byte cbuf *bytes.Buffer ) - const ( - payloadLen = 1 - sizeLen = 4 - ) if msg != nil { var err error b, err = c.Marshal(msg) if err != nil { - return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) + return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) } if outPayload != nil { outPayload.Payload = msg @@ -507,49 +460,36 @@ func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPa if compressor != nil { z, _ := compressor.Compress(cbuf) if _, err := z.Write(b); err != nil { - return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } z.Close() } else { // If Compressor is not set by UseCompressor, use 
default Compressor if err := cp.Do(cbuf, b); err != nil { - return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) + return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } } b = cbuf.Bytes() } } if uint(len(b)) > math.MaxUint32 { - return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) + return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) } - bufHeader := make([]byte, payloadLen+sizeLen) - if compressor != nil || cp != nil { - bufHeader[0] = byte(compressionMade) - } else { - bufHeader[0] = byte(compressionNone) - } - - // Write length of b into buf - binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b))) if outPayload != nil { - outPayload.WireLength = payloadLen + sizeLen + len(b) + // A 5 byte gRPC-specific message header will be added to this message + // before it's put on the wire. + outPayload.WireLength = 5 + len(b) } - return bufHeader, b, nil + return b, nil } -func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { - switch pf { - case compressionNone: - case compressionMade: - if recvCompress == "" || recvCompress == encoding.Identity { - return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding") - } - if !haveCompressor { - return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) - } - default: - return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf) +func checkRecvPayload(recvCompress string, haveCompressor bool) *status.Status { + if recvCompress == "" || recvCompress == encoding.Identity { + return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding") + } + if !haveCompressor { + return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q",
recvCompress) } return nil } @@ -557,8 +497,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool // For the two compressor parameters, both should not be set, but if they are, // dc takes precedence over compressor. // TODO(dfawley): wrap the old compressor/decompressor using the new API? -func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error { - pf, d, err := p.recvMsg(maxReceiveMessageSize) +func recv(c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error { + isCompressed, d, err := recvMsg(s, maxReceiveMessageSize) if err != nil { return err } @@ -566,11 +506,10 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf inPayload.WireLength = len(d) } - if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { - return st.Err() - } - - if pf == compressionMade { + if isCompressed { + if st := checkRecvPayload(s.RecvCompress(), compressor != nil || dc != nil); st != nil { + return st.Err() + } // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // use this decompressor as the default. if dc != nil { @@ -588,11 +527,11 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } - } - if len(d) > maxReceiveMessageSize { - // TODO: Revisit the error code. Currently keep it consistent with java - // implementation. - return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) + if len(d) > maxReceiveMessageSize { + // TODO: Revisit the error code. Currently keep it consistent with java + // implementation. 
+ return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) + } } if err := c.Unmarshal(d, m); err != nil { return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) diff --git a/rpc_util_test.go b/rpc_util_test.go index 770e850c6..2cf2b43a5 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -22,7 +22,6 @@ import ( "bytes" "compress/gzip" "io" - "math" "reflect" "testing" @@ -45,77 +44,20 @@ func (f fullReader) Read(p []byte) (int, error) { var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface -func TestSimpleParsing(t *testing.T) { - bigMsg := bytes.Repeat([]byte{'x'}, 1<<24) - for _, test := range []struct { - // input - p []byte - // outputs - err error - b []byte - pt payloadFormat - }{ - {nil, io.EOF, nil, compressionNone}, - {[]byte{0, 0, 0, 0, 0}, nil, nil, compressionNone}, - {[]byte{0, 0, 0, 0, 1, 'a'}, nil, []byte{'a'}, compressionNone}, - {[]byte{1, 0}, io.ErrUnexpectedEOF, nil, compressionNone}, - {[]byte{0, 0, 0, 0, 10, 'a'}, io.ErrUnexpectedEOF, nil, compressionNone}, - // Check that messages with length >= 2^24 are parsed. - {append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone}, - } { - buf := fullReader{bytes.NewReader(test.p)} - parser := &parser{r: buf} - pt, b, err := parser.recvMsg(math.MaxInt32) - if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt { - t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) - } - } -} - -func TestMultipleParsing(t *testing.T) { - // Set a byte stream consists of 3 messages with their headers. 
- p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'} - b := fullReader{bytes.NewReader(p)} - parser := &parser{r: b} - - wantRecvs := []struct { - pt payloadFormat - data []byte - }{ - {compressionNone, []byte("a")}, - {compressionNone, []byte("bc")}, - {compressionNone, []byte("d")}, - } - for i, want := range wantRecvs { - pt, data, err := parser.recvMsg(math.MaxInt32) - if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) { - t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, ", - i, p, pt, data, err, want.pt, want.data) - } - } - - pt, data, err := parser.recvMsg(math.MaxInt32) - if err != io.EOF { - t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v", - len(wantRecvs), p, pt, data, err, io.EOF) - } -} - func TestEncode(t *testing.T) { for _, test := range []struct { // input msg proto.Message cp Compressor // outputs - hdr []byte data []byte err error }{ - {nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, + {nil, nil, []byte{}, nil}, } { - hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil) - if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) { - t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err) + data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil) + if err != test.err || !bytes.Equal(data, test.data) { + t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, data, err, test.data, test.err) } } } @@ -214,8 +156,11 @@ func TestParseDialTarget(t *testing.T) { func bmEncode(b *testing.B, mSize int) { cdc := encoding.GetCodec(protoenc.Name) msg := &perfpb.Buffer{Body: make([]byte, mSize)} - encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil) - encodedSz := int64(len(encodeHdr) + len(encodeData)) + encodeData, _ := encode(cdc, msg, nil, nil, nil) + // 5 bytes of gRPC-specific message header + 
// is added to the message before it is written + // to the wire. + encodedSz := int64(5 + len(encodeData)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/server.go b/server.go index a7ef6cc25..6a59b7cff 100644 --- a/server.go +++ b/server.go @@ -831,7 +831,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str if s.opts.statsHandler != nil { outPayload = &stats.OutPayload{} } - hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp) + data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err @@ -839,7 +839,8 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str if len(data) > s.opts.maxSendMessageSize { return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) } - err = t.Write(stream, hdr, data, opts) + opts.IsCompressed = cp != nil || comp != nil + err = t.Write(stream, data, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) @@ -924,8 +925,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } } - p := &parser{r: stream} - pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) + isCompressed, req, err := recvMsg(stream, s.opts.maxReceiveMessageSize) if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -955,12 +955,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
if channelz.IsOn() { t.IncrMsgRecv() } - if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil { - if e := t.WriteStatus(stream, st); e != nil { - grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) - } - return st.Err() - } var inPayload *stats.InPayload if sh != nil { inPayload = &stats.InPayload{ @@ -971,7 +965,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if inPayload != nil { inPayload.WireLength = len(req) } - if pf == compressionMade { + if isCompressed { + if st := checkRecvPayload(stream.RecvCompress(), dc != nil || decomp != nil); st != nil { + return st.Err() + } var err error if dc != nil { req, err = dc.Do(bytes.NewReader(req)) @@ -985,11 +982,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } - } - if len(req) > s.opts.maxReceiveMessageSize { - // TODO: Revisit the error code. Currently keep it consistent with - // java implementation. - return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize) + if len(req) > s.opts.maxReceiveMessageSize { + // TODO: Revisit the error code. Currently keep it consistent with + // java implementation. + return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. 
%d)", len(req), s.opts.maxReceiveMessageSize) + } } if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) @@ -1100,7 +1097,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ctx: ctx, t: t, s: stream, - p: &parser{r: stream}, codec: s.getCodec(stream.ContentSubtype()), maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, diff --git a/stream.go b/stream.go index 82921a15a..1b5dd4956 100644 --- a/stream.go +++ b/stream.go @@ -290,7 +290,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth attempt: &csAttempt{ t: t, s: s, - p: &parser{r: s}, done: done, dc: cc.dopts.dc, ctx: ctx, @@ -347,7 +346,6 @@ type csAttempt struct { cs *clientStream t transport.ClientTransport s *transport.Stream - p *parser done func(balancer.DoneInfo) dc Decompressor @@ -472,7 +470,7 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) { Client: true, } } - hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp) + data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp) if err != nil { return err } @@ -482,7 +480,11 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) { if !cs.desc.ClientStreams { cs.sentLast = true } - err = a.t.Write(a.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams}) + opts := &transport.Options{ + Last: !cs.desc.ClientStreams, + IsCompressed: cs.cp != nil || cs.comp != nil, + } + err = a.t.Write(a.s, data, opts) if err == nil { if outPayload != nil { outPayload.SentTime = time.Now() @@ -526,7 +528,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) { // Only initialize this state once per stream. 
a.decompSet = true } - err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp) + err = recv(cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp) if err != nil { if err == io.EOF { if statusErr := a.s.Status().Err(); statusErr != nil { @@ -556,7 +558,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) { // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp) + err = recv(cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) } @@ -572,7 +574,7 @@ func (a *csAttempt) closeSend() { return } cs.sentLast = true - cs.attempt.t.Write(cs.attempt.s, nil, nil, &transport.Options{Last: true}) + cs.attempt.t.Write(cs.attempt.s, nil, &transport.Options{Last: true}) // We ignore errors from Write. Any error it would return would also be // returned by a subsequent RecvMsg call, and the user is supposed to always // finish the stream by calling RecvMsg until it returns err != nil. @@ -635,7 +637,6 @@ type serverStream struct { ctx context.Context t transport.ServerTransport s *transport.Stream - p *parser codec baseCodec cp Compressor @@ -700,14 +701,18 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { if ss.statsHandler != nil { outPayload = &stats.OutPayload{} } - hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp) + data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp) if err != nil { return err } if len(data) > ss.maxSendMessageSize { return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. 
%d)", len(data), ss.maxSendMessageSize) } - if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil { + opts := &transport.Options{ + Last: false, + IsCompressed: ss.cp != nil || ss.comp != nil, + } + if err := ss.t.Write(ss.s, data, opts); err != nil { return toRPCErr(err) } if outPayload != nil { @@ -743,7 +748,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { if ss.statsHandler != nil { inPayload = &stats.InPayload{} } - if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil { + if err := recv(ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil { if err == io.EOF { return err } diff --git a/transport/flowcontrol.go b/transport/flowcontrol.go index 378f5c450..babdfabdd 100644 --- a/transport/flowcontrol.go +++ b/transport/flowcontrol.go @@ -21,7 +21,6 @@ package transport import ( "fmt" "math" - "sync" "sync/atomic" "time" ) @@ -96,35 +95,39 @@ func (w *writeQuota) replenish(n int) { } type trInFlow struct { - limit uint32 - unacked uint32 - effectiveWindowSize uint32 + limit uint32 // accessed by reader goroutine. + unacked uint32 // accessed by reader goroutine. + effectiveWindowSize uint32 // accessed by reader and channelz request goroutine. + // Callback used to schedule window update. + scheduleWU func(uint32) } -func (f *trInFlow) newLimit(n uint32) uint32 { - d := n - f.limit +// Sets the new limit. 
+func (f *trInFlow) newLimit(n uint32) { + if n > f.limit { + f.scheduleWU(n - f.limit) + } f.limit = n f.updateEffectiveWindowSize() - return d } -func (f *trInFlow) onData(n uint32) uint32 { +func (f *trInFlow) onData(n uint32) { f.unacked += n if f.unacked >= f.limit/4 { w := f.unacked f.unacked = 0 - f.updateEffectiveWindowSize() - return w + f.scheduleWU(w) } f.updateEffectiveWindowSize() - return 0 } -func (f *trInFlow) reset() uint32 { - w := f.unacked +func (f *trInFlow) reset() { + if f.unacked == 0 { + return + } + f.scheduleWU(f.unacked) f.unacked = 0 f.updateEffectiveWindowSize() - return w } func (f *trInFlow) updateEffectiveWindowSize() { @@ -135,102 +138,57 @@ func (f *trInFlow) getSize() uint32 { return atomic.LoadUint32(&f.effectiveWindowSize) } -// TODO(mmukhi): Simplify this code. -// inFlow deals with inbound flow control -type inFlow struct { - mu sync.Mutex +// stInFlow deals with inbound flow control for stream. +// It can be simultaneously read by transport's reader +// goroutine and an RPC's goroutine. +// It is protected by the lock in stream that owns it. +type stInFlow struct { + // rcvd is the bytes of data that this end-point has + // received from the perspective of other side. + // This can go negative. It must be Accessed atomically. + // Needs to be aligned because of golang bug with atomics: + // https://golang.org/pkg/sync/atomic/#pkg-note-BUG + rcvd int64 // The inbound flow control limit for pending data. limit uint32 - // pendingData is the overall data which have been received but not been - // consumed by applications. - pendingData uint32 - // The amount of data the application has consumed but grpc has not sent - // window update for them. Used to reduce window update frequency. - pendingUpdate uint32 - // delta is the extra window update given by receiver when an application - // is reading data bigger in size than the inFlow limit. 
- delta uint32 + // number of bytes that have been read by the RPC since the + // last window update was scheduled. + read uint32 + // a window update should be sent when the RPC has + // read this many bytes. + // TODO(mmukhi, dfawley): Does this have to be limit/4? + // Keeping it a constant makes implementation easy. + wuThreshold uint32 + // Callback used to schedule window update. + scheduleWU func(uint32) } -// newLimit updates the inflow window to a new value n. -// It assumes that n is always greater than the old limit. -func (f *inFlow) newLimit(n uint32) uint32 { - f.mu.Lock() - d := n - f.limit - f.limit = n - f.mu.Unlock() - return d +// called by transport's reader goroutine to set a new limit on +// incoming flow control based on BDP estimation. +func (s *stInFlow) newLimit(n uint32) { + s.limit = n } -func (f *inFlow) maybeAdjust(n uint32) uint32 { - if n > uint32(math.MaxInt32) { - n = uint32(math.MaxInt32) +// called by transport's reader goroutine when data is received by it. +func (s *stInFlow) onData(n uint32) error { + rcvd := atomic.AddInt64(&s.rcvd, int64(n)) + if rcvd > int64(s.limit) { // Flow control violation. + return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, s.limit) } - f.mu.Lock() - // estSenderQuota is the receiver's view of the maximum number of bytes the sender - // can send without a window update. - estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate)) - // estUntransmittedData is the maximum number of bytes the sends might not have put - // on the wire yet. A value of 0 or less means that we have already received all or - // more bytes than the application is requesting to read. - estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative. - // This implies that unless we send a window update, the sender won't be able to send all the bytes - // for this message.
Therefore we must send an update over the limit since there's an active read - // request from the application. - if estUntransmittedData > estSenderQuota { - // Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec. - if f.limit+n > maxWindowSize { - f.delta = maxWindowSize - f.limit - } else { - // Send a window update for the whole message and not just the difference between - // estUntransmittedData and estSenderQuota. This will be helpful in case the message - // is padded; We will fallback on the current available window(at least a 1/4th of the limit). - f.delta = n - } - f.mu.Unlock() - return f.delta - } - f.mu.Unlock() - return 0 -} - -// onData is invoked when some data frame is received. It updates pendingData. -func (f *inFlow) onData(n uint32) error { - f.mu.Lock() - f.pendingData += n - if f.pendingData+f.pendingUpdate > f.limit+f.delta { - limit := f.limit - rcvd := f.pendingData + f.pendingUpdate - f.mu.Unlock() - return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit) - } - f.mu.Unlock() return nil } -// onRead is invoked when the application reads the data. It returns the window size -// to be sent to the peer. -func (f *inFlow) onRead(n uint32) uint32 { - f.mu.Lock() - if f.pendingData == 0 { - f.mu.Unlock() - return 0 +// called by RPC's goroutine when data is read by it. +func (s *stInFlow) onRead(n uint32) { + s.read += n + if s.read >= s.wuThreshold { + val := atomic.AddInt64(&s.rcvd, ^int64(s.read-1)) + // Check if threshold needs to go up since limit might have gone up. 
+ val += int64(s.read) + if val > int64(4*s.wuThreshold) { + s.wuThreshold = uint32(val / 4) + } + s.scheduleWU(s.read) + s.read = 0 } - f.pendingData -= n - if n > f.delta { - n -= f.delta - f.delta = 0 - } else { - f.delta -= n - n = 0 - } - f.pendingUpdate += n - if f.pendingUpdate >= f.limit/4 { - wu := f.pendingUpdate - f.pendingUpdate = 0 - f.mu.Unlock() - return wu - } - f.mu.Unlock() - return 0 } diff --git a/transport/handler_server.go b/transport/handler_server.go index a75338009..012d11448 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -38,6 +38,7 @@ import ( "golang.org/x/net/http2" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/msgdecoder" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -269,10 +270,10 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { } } -func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { +func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error { return ht.do(func() { ht.writeCommonHeaders(s) - ht.rw.Write(hdr) + ht.rw.Write(msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed)) ht.rw.Write(data) if !opts.Delay { ht.rw.(http.Flusher).Flush() @@ -337,16 +338,13 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace req := ht.req - s := &Stream{ - id: 0, // irrelevant - requestRead: func(int) {}, - cancel: cancel, - buf: newRecvBuffer(), - st: ht, - method: req.URL.Path, - recvCompress: req.Header.Get("grpc-encoding"), - contentSubtype: ht.contentSubtype, - } + s := newStream(ctx) + s.cancel = cancel + s.st = ht + s.method = req.URL.Path + s.recvCompress = req.Header.Get("grpc-encoding") + s.contentSubtype = ht.contentSubtype + pr := &peer.Peer{ Addr: ht.RemoteAddr(), } @@ -364,10 +362,6 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace } 
ht.stats.HandleRPC(s.ctx, inHeader) } - s.trReader = &transportReader{ - reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf}, - windowHandler: func(int) {}, - } // readerDone is closed when the Body.Read-ing goroutine exits. readerDone := make(chan struct{}) @@ -379,11 +373,11 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace for buf := make([]byte, readSize); ; { n, err := req.Body.Read(buf) if n > 0 { - s.buf.put(recvMsg{data: buf[:n:n]}) + s.consume(buf[:n:n], 0) buf = buf[n:] } if err != nil { - s.buf.put(recvMsg{err: mapRecvMsgError(err)}) + s.notifyErr(mapRecvMsgError(err)) return } if len(buf) == 0 { diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 3261b8e3d..c6d4c8917 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -423,7 +423,7 @@ func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) { st.bodyw.Close() // no body st.ht.WriteStatus(s, status.New(codes.OK, "")) - st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{}) + st.ht.Write(s, []byte("data"), &Options{}) }) } diff --git a/transport/http2_client.go b/transport/http2_client.go index ab97cb571..8fa61bbbb 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -34,6 +34,7 @@ import ( "google.golang.org/grpc/channelz" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/msgdecoder" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -95,8 +96,9 @@ type http2Client struct { waitingStreams uint32 nextID uint32 - mu sync.Mutex // guard the following variables - state transportState + mu sync.Mutex // guard the following variables + state transportState + // TODO(mmukhi): Make this a sharded map. activeStreams map[uint32]*Stream // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. 
prevGoAwayID uint32 @@ -218,7 +220,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne goAway: make(chan struct{}), awakenKeepalive: make(chan struct{}, 1), framer: newFramer(conn, writeBufSize, readBufSize), - fc: &trInFlow{limit: uint32(icwz)}, scheme: scheme, activeStreams: make(map[uint32]*Stream), isSecure: isSecure, @@ -233,6 +234,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne streamsQuotaAvailable: make(chan struct{}, 1), } t.controlBuf = newControlBuffer(t.ctxDone) + t.fc = &trInFlow{ + limit: uint32(icwz), + scheduleWU: func(w uint32) { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + }, + } if opts.InitialWindowSize >= defaultWindowSize { t.initialWindowSize = opts.InitialWindowSize dynamicWindow = false @@ -306,33 +316,17 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne } func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { - // TODO(zhaoq): Handle uint32 overflow of Stream.id. - s := &Stream{ - done: make(chan struct{}), - method: callHdr.Method, - sendCompress: callHdr.SendCompress, - buf: newRecvBuffer(), - headerChan: make(chan struct{}), - contentSubtype: callHdr.ContentSubtype, - } - s.wq = newWriteQuota(defaultWriteQuota, s.done) - s.requestRead = func(n int) { - t.adjustWindow(s, uint32(n)) - } // The client side stream context should have exactly the same life cycle with the user provided context. // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done. // So we use the original context here instead of creating a copy. - s.ctx = ctx - s.trReader = &transportReader{ - reader: &recvBufferReader{ - ctx: s.ctx, - ctxDone: s.ctx.Done(), - recv: s.buf, - }, - windowHandler: func(n int) { - t.updateWindow(s, uint32(n)) - }, - } + s := newStream(ctx) + // Initialize stream with client-side specific fields. 
+ s.done = make(chan struct{}) + s.method = callHdr.Method + s.sendCompress = callHdr.SendCompress + s.headerChan = make(chan struct{}) + s.contentSubtype = callHdr.ContentSubtype + s.wq = newWriteQuota(defaultWriteQuota, s.done) return s } @@ -504,7 +498,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } // The stream was unprocessed by the server. atomic.StoreUint32(&s.unprocessed, 1) - s.write(recvMsg{err: err}) + s.notifyErr(err) close(s.done) // If headerChan isn't closed, then close it. if atomic.SwapUint32(&s.headerDone, 1) == 0 { @@ -572,7 +566,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea h.streamID = t.nextID t.nextID += 2 s.id = h.streamID - s.fc = &inFlow{limit: uint32(t.initialWindowSize)} + s.fc = &stInFlow{ + limit: uint32(t.initialWindowSize), + scheduleWU: func(w uint32) { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) + }, + wuThreshold: uint32(t.initialWindowSize / 4), + } if t.streamQuota > 0 && t.waitingStreams > 0 { select { case t.streamsQuotaAvailable <- struct{}{}: @@ -642,7 +642,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. } if err != nil { // This will unblock reads eventually. - s.write(recvMsg{err: err}) + s.notifyErr(err) } // This will unblock write. close(s.done) @@ -740,7 +740,7 @@ func (t *http2Client) GracefulClose() error { // Write formats the data into HTTP2 data frame(s) and sends it out. The caller // should proceed only if Write returns nil. -func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { +func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { if opts.Last { // If it's the last message, update stream state. 
if !s.compareAndSwapState(streamActive, streamWriteDone) { @@ -753,7 +753,9 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e streamID: s.id, endStream: opts.Last, } - if hdr != nil || data != nil { // If it's not an empty data frame. + if data != nil { // If it's not an empty data frame. + // Get a gRPC-specific header for this message. + hdr := msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed) // Add some data to grpc message header so that we can equally // distribute bytes across frames. emptyLen := http2MaxFrameLen - len(hdr) @@ -778,39 +780,19 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) { return s, ok } -// adjustWindow sends out extra window update over the initial window size -// of stream if the application is requesting data larger in size than -// the window. -func (t *http2Client) adjustWindow(s *Stream, n uint32) { - if w := s.fc.maybeAdjust(n); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) - } -} - -// updateWindow adjusts the inbound quota for the stream. -// Window updates will be sent out when the cumulative quota -// exceeds the corresponding threshold. -func (t *http2Client) updateWindow(s *Stream, n uint32) { - if w := s.fc.onRead(n); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) - } -} - // updateFlowControl updates the incoming flow control windows // for the transport and the stream based on the current bdp // estimation. func (t *http2Client) updateFlowControl(n uint32) { - t.mu.Lock() - for _, s := range t.activeStreams { - s.fc.newLimit(n) - } - t.mu.Unlock() - updateIWS := func(interface{}) bool { + t.fc.newLimit(n) // Update transport's window. + updateIWS := func(interface{}) bool { // Update streams' windows. + // All future streams should see the + // updated value. 
t.initialWindowSize = int32(n) return true } - t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)}) - t.controlBuf.put(&outgoingSettings{ + // Notify the other side of updated window. + t.controlBuf.executeAndPut(updateIWS, &outgoingSettings{ ss: []http2.Setting{ { ID: http2.SettingInitialWindowSize, @@ -818,13 +800,25 @@ func (t *http2Client) updateFlowControl(n uint32) { }, }, }) + t.mu.Lock() + // Update all the currently active streams. + for _, s := range t.activeStreams { + s.fc.newLimit(n) + } + t.mu.Unlock() } func (t *http2Client) handleData(f *http2.DataFrame) { size := f.Header().Length - var sendBDPPing bool - if t.bdpEst != nil { - sendBDPPing = t.bdpEst.add(size) + if size == 0 { + if f.StreamEnded() { + // The server has closed the stream without sending trailers. Record that + // the read direction is closed, and set the status appropriately. + if s, ok := t.getStream(f); ok { + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) + } + } + return } // Decouple connection's flow control from application's read. // An update on connection's flow control should not depend on @@ -835,53 +829,30 @@ func (t *http2Client) handleData(f *http2.DataFrame) { // active(fast) streams from starving in presence of slow or // inactive streams. // - if w := t.fc.onData(size); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{ - streamID: 0, - increment: w, - }) - } - if sendBDPPing { + t.fc.onData(size) + if t.bdpEst != nil && t.bdpEst.add(size) { // Avoid excessive ping detection (e.g. in an L7 proxy) // by sending a window update prior to the BDP ping. - - if w := t.fc.reset(); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{ - streamID: 0, - increment: w, - }) - } - + t.fc.reset() t.controlBuf.put(bdpPing) } + // Select the right stream to dispatch. 
- s, ok := t.getStream(f) - if !ok { - return - } - if size > 0 { - if err := s.fc.onData(size); err != nil { + if s, ok := t.getStream(f); ok { + d := f.Data() + padding := 0 + if f.Header().Flags.Has(http2.FlagDataPadded) { + padding = int(size) - len(d) + } + if err := s.consume(d, padding); err != nil { t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) return } - if f.Header().Flags.Has(http2.FlagDataPadded) { - if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) - } + if f.StreamEnded() { + // The server has closed the stream without sending trailers. Record that + // the read direction is closed, and set the status appropriately. + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) } - // TODO(bradfitz, zhaoq): A copy is required here because there is no - // guarantee f.Data() is consumed before the arrival of next frame. - // Can this copy be eliminated? - if len(f.Data()) > 0 { - data := make([]byte, len(f.Data())) - copy(data, f.Data()) - s.write(recvMsg{data: data}) - } - } - // The server has closed the stream without sending trailers. Record that - // the read direction is closed, and set the status appropriately. - if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) } } @@ -890,6 +861,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { if !ok { return } + errorf("transport: client got RST_STREAM with error %v, for stream: %d", f.ErrCode, s.id) if f.ErrCode == http2.ErrCodeRefusedStream { // The stream was unprocessed by the server. 
atomic.StoreUint32(&s.unprocessed, 1) diff --git a/transport/http2_server.go b/transport/http2_server.go index 8b93e222e..7a4b2892d 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -39,6 +39,7 @@ import ( "google.golang.org/grpc/channelz" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/msgdecoder" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -212,7 +213,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err writerDone: make(chan struct{}), maxStreams: maxStreams, inTapHandle: config.InTapHandle, - fc: &trInFlow{limit: uint32(icwz)}, state: reachable, activeStreams: make(map[uint32]*Stream), stats: config.StatsHandler, @@ -222,6 +222,15 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err initialWindowSize: iwz, } t.controlBuf = newControlBuffer(t.ctxDone) + t.fc = &trInFlow{ + limit: uint32(icwz), + scheduleWU: func(w uint32) { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + }, + } if dynamicWindow { t.bdpEst = &bdpEstimator{ bdp: initialWindowSize, @@ -298,25 +307,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return } } - - buf := newRecvBuffer() - s := &Stream{ - id: streamID, - st: t, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - recvCompress: state.encoding, - method: state.method, - contentSubtype: state.contentSubtype, - } - if frame.StreamEnded() { - // s is just created by the caller. No lock needed. 
- s.state = streamReadDone - } + var ( + ctx context.Context + cancel func() + ) if state.timeoutSet { - s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) + ctx, cancel = context.WithTimeout(t.ctx, state.timeout) } else { - s.ctx, s.cancel = context.WithCancel(t.ctx) + ctx, cancel = context.WithCancel(t.ctx) } pr := &peer.Peer{ Addr: t.remoteAddr, @@ -325,34 +323,55 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( if t.authInfo != nil { pr.AuthInfo = t.authInfo } - s.ctx = peer.NewContext(s.ctx, pr) + ctx = peer.NewContext(ctx, pr) // Attach the received metadata to the context. if len(state.mdata) > 0 { - s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) + ctx = metadata.NewIncomingContext(ctx, state.mdata) } if state.statsTags != nil { - s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags) + ctx = stats.SetIncomingTags(ctx, state.statsTags) } if state.statsTrace != nil { - s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace) + ctx = stats.SetIncomingTrace(ctx, state.statsTrace) } if t.inTapHandle != nil { var err error info := &tap.Info{ FullMethodName: state.method, } - s.ctx, err = t.inTapHandle(s.ctx, info) + ctx, err = t.inTapHandle(ctx, info) if err != nil { warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) t.controlBuf.put(&cleanupStream{ - streamID: s.id, + streamID: streamID, rst: true, rstCode: http2.ErrCodeRefusedStream, onWrite: func() {}, }) + cancel() return } } + ctx = traceCtx(ctx, state.method) + s := newStream(ctx) + // Initialize s with server-side specific fields. 
+ s.cancel = cancel + s.id = streamID + s.st = t + s.fc = &stInFlow{ + limit: uint32(t.initialWindowSize), + scheduleWU: func(w uint32) { + t.controlBuf.put(&outgoingWindowUpdate{streamID: streamID, increment: w}) + }, + wuThreshold: uint32(t.initialWindowSize / 4), + } + s.recvCompress = state.encoding + s.method = state.method + s.contentSubtype = state.contentSubtype + if frame.StreamEnded() { + s.state = streamReadDone + } + s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) t.mu.Lock() if t.state != reachable { t.mu.Unlock() @@ -386,10 +405,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.lastStreamCreated = time.Now() t.czmu.Unlock() } - s.requestRead = func(n int) { - t.adjustWindow(s, uint32(n)) - } - s.ctx = traceCtx(s.ctx, s.method) if t.stats != nil { s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ @@ -401,18 +416,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } t.stats.HandleRPC(s.ctx, inHeader) } - s.ctxDone = s.ctx.Done() - s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) - s.trReader = &transportReader{ - reader: &recvBufferReader{ - ctx: s.ctx, - ctxDone: s.ctxDone, - recv: s.buf, - }, - windowHandler: func(n int) { - t.updateWindow(s, uint32(n)) - }, - } handle(s) return } @@ -490,41 +493,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) { return s, true } -// adjustWindow sends out extra window update over the initial window size -// of stream if the application is requesting data larger in size than -// the window. -func (t *http2Server) adjustWindow(s *Stream, n uint32) { - if w := s.fc.maybeAdjust(n); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) - } - -} - -// updateWindow adjusts the inbound quota for the stream and the transport. -// Window updates will deliver to the controller for sending when -// the cumulative quota exceeds the corresponding threshold. 
-func (t *http2Server) updateWindow(s *Stream, n uint32) { - if w := s.fc.onRead(n); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, - increment: w, - }) - } -} - // updateFlowControl updates the incoming flow control windows // for the transport and the stream based on the current bdp // estimation. func (t *http2Server) updateFlowControl(n uint32) { t.mu.Lock() + // Update all the current streams' window. for _, s := range t.activeStreams { s.fc.newLimit(n) } + // Update all the future streams' window. t.initialWindowSize = int32(n) t.mu.Unlock() - t.controlBuf.put(&outgoingWindowUpdate{ - streamID: 0, - increment: t.fc.newLimit(n), - }) + t.fc.newLimit(n) // Update transport's window. + // Notify the other side of the updated value. t.controlBuf.put(&outgoingSettings{ ss: []http2.Setting{ { @@ -538,9 +520,15 @@ func (t *http2Server) updateFlowControl(n uint32) { func (t *http2Server) handleData(f *http2.DataFrame) { size := f.Header().Length - var sendBDPPing bool - if t.bdpEst != nil { - sendBDPPing = t.bdpEst.add(size) + if size == 0 { + if f.StreamEnded() { + if s, ok := t.getStream(f); ok { + // Received the end of stream from the client. + s.compareAndSwapState(streamActive, streamReadDone) + s.notifyErr(io.EOF) + } + } + return } // Decouple connection's flow control from application's read. // An update on connection's flow control should not depend on @@ -550,51 +538,30 @@ func (t *http2Server) handleData(f *http2.DataFrame) { // Decoupling the connection flow control will prevent other // active(fast) streams from starving in presence of slow or // inactive streams. - if w := t.fc.onData(size); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{ - streamID: 0, - increment: w, - }) - } - if sendBDPPing { + t.fc.onData(size) + if t.bdpEst != nil && t.bdpEst.add(size) { // Avoid excessive ping detection (e.g. in an L7 proxy) // by sending a window update prior to the BDP ping. 
- if w := t.fc.reset(); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{ - streamID: 0, - increment: w, - }) - } + t.fc.reset() t.controlBuf.put(bdpPing) } // Select the right stream to dispatch. - s, ok := t.getStream(f) - if !ok { - return - } - if size > 0 { - if err := s.fc.onData(size); err != nil { + if s, ok := t.getStream(f); ok { + d := f.Data() + padding := 0 + if f.Header().Flags.Has(http2.FlagDataPadded) { + padding = int(size) - len(d) + } + if err := s.consume(d, padding); err != nil { + errorf("transport: flow control error on server: %v", err) t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false) return } - if f.Header().Flags.Has(http2.FlagDataPadded) { - if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { - t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) - } + if f.StreamEnded() { + // Received the end of stream from the client. + s.compareAndSwapState(streamActive, streamReadDone) + s.notifyErr(io.EOF) } - // TODO(bradfitz, zhaoq): A copy is required here because there is no - // guarantee f.Data() is consumed before the arrival of next frame. - // Can this copy be eliminated? - if len(f.Data()) > 0 { - data := make([]byte, len(f.Data())) - copy(data, f.Data()) - s.write(recvMsg{data: data}) - } - } - if f.Header().Flags.Has(http2.FlagDataEndStream) { - // Received the end of stream from the client. - s.compareAndSwapState(streamActive, streamReadDone) - s.write(recvMsg{err: io.EOF}) } } @@ -792,7 +759,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { // Write converts the data into HTTP2 data frame and sends it out. Non-nil error // is returns if it fails (e.g., framing error, transport error). -func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { +func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { if !s.headerOk { // Headers haven't been written yet. 
if err := t.WriteHeader(s, nil); err != nil { // TODO(mmukhi, dfawley): Make sure this is the right code to return. @@ -811,6 +778,8 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e return ContextErr(s.ctx.Err()) } } + // Get a gRPC-specific header for this message. + hdr := msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed) // Add some data to header frame so that we can equally distribute bytes across frames. emptyLen := http2MaxFrameLen - len(hdr) if emptyLen > len(data) { diff --git a/transport/stream.go b/transport/stream.go new file mode 100644 index 000000000..bfc4edc43 --- /dev/null +++ b/transport/stream.go @@ -0,0 +1,407 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "fmt" + "io" + "sync" + "sync/atomic" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/msgdecoder" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +const maxInt = int(^uint(0) >> 1) + +type streamState uint32 + +const ( + streamActive streamState = iota + streamWriteDone // EndStream sent + streamReadDone // EndStream received + streamDone // the entire stream is finished. +) + +// transport's reader goroutine adds msgdecoder.RecvMsg to it which are later +// read by RPC's reading goroutine. +// +// It is protected by a lock in the Stream that owns it. 
+type recvBuffer struct { + ctx context.Context + ctxDone <-chan struct{} + c chan *msgdecoder.RecvMsg + mu sync.Mutex + waiting bool + list *msgdecoder.RecvMsgList +} + +func newRecvBuffer(ctx context.Context, ctxDone <-chan struct{}) *recvBuffer { + return &recvBuffer{ + ctx: ctx, + ctxDone: ctxDone, + c: make(chan *msgdecoder.RecvMsg, 1), + list: &msgdecoder.RecvMsgList{}, + } +} + +// put adds r to the underlying list if there's no consumer +// waiting, otherwise, it writes on the chan directly. +func (b *recvBuffer) put(r *msgdecoder.RecvMsg) { + b.mu.Lock() + if b.waiting { + b.waiting = false + b.mu.Unlock() + b.c <- r + return + } + b.list.Enqueue(r) + b.mu.Unlock() +} + +// getNoBlock returns a msgdecoder.RecvMsg and true status, if there's +// any available. +// If the status is false, the caller must then call +// getWithBlock() before calling getNoBlock() again. +func (b *recvBuffer) getNoBlock() (*msgdecoder.RecvMsg, bool) { + b.mu.Lock() + r := b.list.Dequeue() + if r != nil { + b.mu.Unlock() + return r, true + } + b.waiting = true + b.mu.Unlock() + return nil, false + +} + +// getWithBlock() blocks until a complete message has been +// received, or an error has occurred or the underlying +// context has expired. +// It must only be called after having called getNoBlock() +// once. +func (b *recvBuffer) getWithBlock() (*msgdecoder.RecvMsg, error) { + select { + case <-b.ctxDone: + return nil, ContextErr(b.ctx.Err()) + case r := <-b.c: + return r, nil + } +} + +func (b *recvBuffer) get() (*msgdecoder.RecvMsg, error) { + m, ok := b.getNoBlock() + if ok { + return m, nil + } + m, err := b.getWithBlock() + if err != nil { + return nil, err + } + return m, nil +} + +// Stream represents an RPC in the transport layer.
+type Stream struct { + id uint32 + st ServerTransport // nil for client side Stream + ctx context.Context // the associated context of the stream + cancel context.CancelFunc // always nil for client side Stream + done chan struct{} // closed at the end of stream to unblock writers. On the client side. + ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) + method string // the associated RPC method of the stream + recvCompress string + sendCompress string + buf *recvBuffer + wq *writeQuota + + headerChan chan struct{} // closed to indicate the end of header metadata. + headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. + header metadata.MD // the received header metadata. + trailer metadata.MD // the key-value map of trailer metadata. + + headerOk bool // becomes true from the first header is about to send + state streamState + + status *status.Status // the status error received from the server + + bytesReceived uint32 // indicates whether any bytes have been received on this stream + unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream + + // contentSubtype is the content-subtype for requests. + // this must be lowercase or the behavior is undefined. + contentSubtype string + + rbuf *recvBuffer + fc *stInFlow + msgDecoder *msgdecoder.MessageDecoder + readErr error +} + +func newStream(ctx context.Context) *Stream { + // Cache the done chan since a Done() call is expensive. + ctxDone := ctx.Done() + s := &Stream{ + ctx: ctx, + ctxDone: ctxDone, + rbuf: newRecvBuffer(ctx, ctxDone), + } + dispatch := func(r *msgdecoder.RecvMsg) { + s.rbuf.put(r) + } + s.msgDecoder = msgdecoder.NewMessageDecoder(dispatch) + return s +} + +// notifyErr notifies RPC of an error seen by the transport. 
+// +// Note to developers: This call can unblock Read calls on RPC +// and lead to reading of unprotected fields on stream on the +// client-side. It should only be called from inside +// transport.closeStream() if the stream was initialized or from +// inside the cleanup callback if the stream was not initialized. +func (s *Stream) notifyErr(err error) { + s.rbuf.put(&msgdecoder.RecvMsg{Err: err}) +} + +// consume is called by transport's reader goroutine for parsing +// and decoding data received for this stream. +func (s *Stream) consume(b []byte, padding int) error { + // Flow control check. + if s.fc != nil { // HandlerServer doesn't use our flow control. + if err := s.fc.onData(uint32(len(b) + padding)); err != nil { + return err + } + } + s.msgDecoder.Decode(b, padding) + return nil +} + +// Read reads one whole message from the transport. +// It is called by RPC's goroutine. +// It is not safe to be called concurrently by multiple goroutines. +// +// Returns: +// 1. received message's compression status(true if was compressed) +// 2. Message as a byte slice +// 3. Error, if any. +func (s *Stream) Read(maxRecvMsgSize int) (bool, []byte, error) { + if s.readErr != nil { + return false, nil, s.readErr + } + var ( + m *msgdecoder.RecvMsg + err error + ) + // First read the underlying message header + if m, err = s.rbuf.get(); err != nil { + s.readErr = err + return false, nil, err + } + if m.Err != nil { + s.readErr = m.Err + return false, nil, s.readErr + } + // Make sure the message being received isn't too large. + if int64(m.Length) > int64(maxInt) { + s.readErr = status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", m.Length, maxInt) + return false, nil, s.readErr + } + if m.Length > maxRecvMsgSize { + s.readErr = status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. 
%d)", m.Length, maxRecvMsgSize) + return false, nil, s.readErr + } + // Send a window update for the message this RPC is reading. + if s.fc != nil { // HandlerServer doesn't use our flow control. + s.fc.onRead(uint32(m.Length + m.Overhead)) + } + isCompressed := m.IsCompressed + // Read the message. + if m, err = s.rbuf.get(); err != nil { + s.readErr = err + return false, nil, err + } + if m.Err != nil { + if m.Err == io.EOF { + m.Err = io.ErrUnexpectedEOF + } + s.readErr = m.Err + return false, nil, s.readErr + } + return isCompressed, m.Data, nil +} + +func (s *Stream) swapState(st streamState) streamState { + return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st))) +} + +func (s *Stream) compareAndSwapState(oldState, newState streamState) bool { + return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState)) +} + +func (s *Stream) getState() streamState { + return streamState(atomic.LoadUint32((*uint32)(&s.state))) +} + +func (s *Stream) waitOnHeader() error { + if s.headerChan == nil { + // On the server headerChan is always nil since a stream originates + // only after having received headers. + return nil + } + select { + case <-s.ctx.Done(): + return ContextErr(s.ctx.Err()) + case <-s.headerChan: + return nil + } +} + +// RecvCompress returns the compression algorithm applied to the inbound +// message. It is an empty string if there is no compression applied. +func (s *Stream) RecvCompress() string { + if err := s.waitOnHeader(); err != nil { + return "" + } + return s.recvCompress +} + +// SetSendCompress sets the compression algorithm to the stream. +func (s *Stream) SetSendCompress(str string) { + s.sendCompress = str +} + +// Done returns a channel which is closed when it receives the final status +// from the server. +func (s *Stream) Done() <-chan struct{} { + return s.done +} + +// Header acquires the key-value pairs of header metadata once it +// is available.
It blocks until i) the metadata is ready or ii) there is no +// header metadata or iii) the stream is canceled/expired. +func (s *Stream) Header() (metadata.MD, error) { + err := s.waitOnHeader() + // Even if the stream is closed, header is returned if available. + select { + case <-s.headerChan: + if s.header == nil { + return nil, nil + } + return s.header.Copy(), nil + default: + } + return nil, err +} + +// Trailer returns the cached trailer metedata. Note that if it is not called +// after the entire stream is done, it could return an empty MD. Client +// side only. +// It can be safely read only after stream has ended that is either read +// or write have returned io.EOF. +func (s *Stream) Trailer() metadata.MD { + c := s.trailer.Copy() + return c +} + +// ServerTransport returns the underlying ServerTransport for the stream. +// The client side stream always returns nil. +func (s *Stream) ServerTransport() ServerTransport { + return s.st +} + +// ContentSubtype returns the content-subtype for a request. For example, a +// content-subtype of "proto" will result in a content-type of +// "application/grpc+proto". This will always be lowercase. See +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +func (s *Stream) ContentSubtype() string { + return s.contentSubtype +} + +// Context returns the context of the stream. +func (s *Stream) Context() context.Context { + return s.ctx +} + +// Method returns the method for the stream. +func (s *Stream) Method() string { + return s.method +} + +// Status returns the status received from the server. +// Status can be read safely only after the stream has ended, +// that is, read or write has returned io.EOF. +func (s *Stream) Status() *status.Status { + return s.status +} + +// SetHeader sets the header metadata. This can be called multiple times. +// Server side only. +// This should not be called in parallel to other data writes. 
+func (s *Stream) SetHeader(md metadata.MD) error { + if md.Len() == 0 { + return nil + } + if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) { + return ErrIllegalHeaderWrite + } + s.header = metadata.Join(s.header, md) + return nil +} + +// SendHeader sends the given header metadata. The given metadata is +// combined with any metadata set by previous calls to SetHeader and +// then written to the transport stream. +func (s *Stream) SendHeader(md metadata.MD) error { + t := s.ServerTransport() + return t.WriteHeader(s, md) +} + +// SetTrailer sets the trailer metadata which will be sent with the RPC status +// by the server. This can be called multiple times. Server side only. +// This should not be called parallel to other data writes. +func (s *Stream) SetTrailer(md metadata.MD) error { + if md.Len() == 0 { + return nil + } + s.trailer = metadata.Join(s.trailer, md) + return nil +} + +// BytesReceived indicates whether any bytes have been received on this stream. +func (s *Stream) BytesReceived() bool { + return atomic.LoadUint32(&s.bytesReceived) == 1 +} + +// Unprocessed indicates whether the server did not process this stream -- +// i.e. it sent a refused stream or GOAWAY including this stream ID. +func (s *Stream) Unprocessed() bool { + return atomic.LoadUint32(&s.unprocessed) == 1 +} + +// GoString is implemented by Stream so context.String() won't +// race when printing %#v. 
+func (s *Stream) GoString() string { + return fmt.Sprintf("", s, s.method) +} diff --git a/transport/transport.go b/transport/transport.go index 2f643a3d0..6f2152e0b 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -24,10 +24,7 @@ package transport // externally used as import "google.golang.org/grpc/transport import ( "errors" "fmt" - "io" "net" - "sync" - "sync/atomic" "golang.org/x/net/context" "google.golang.org/grpc/codes" @@ -39,359 +36,6 @@ import ( "google.golang.org/grpc/tap" ) -// recvMsg represents the received msg from the transport. All transport -// protocol specific info has been removed. -type recvMsg struct { - data []byte - // nil: received some data - // io.EOF: stream is completed. data is nil. - // other non-nil error: transport failure. data is nil. - err error -} - -// recvBuffer is an unbounded channel of recvMsg structs. -// Note recvBuffer differs from controlBuffer only in that recvBuffer -// holds a channel of only recvMsg structs instead of objects implementing "item" interface. -// recvBuffer is written to much more often than -// controlBuffer and using strict recvMsg structs helps avoid allocation in "recvBuffer.put" -type recvBuffer struct { - c chan recvMsg - mu sync.Mutex - backlog []recvMsg - err error -} - -func newRecvBuffer() *recvBuffer { - b := &recvBuffer{ - c: make(chan recvMsg, 1), - } - return b -} - -func (b *recvBuffer) put(r recvMsg) { - b.mu.Lock() - if b.err != nil { - b.mu.Unlock() - // An error had occurred earlier, don't accept more - // data or errors. 
- return - } - b.err = r.err - if len(b.backlog) == 0 { - select { - case b.c <- r: - b.mu.Unlock() - return - default: - } - } - b.backlog = append(b.backlog, r) - b.mu.Unlock() -} - -func (b *recvBuffer) load() { - b.mu.Lock() - if len(b.backlog) > 0 { - select { - case b.c <- b.backlog[0]: - b.backlog[0] = recvMsg{} - b.backlog = b.backlog[1:] - default: - } - } - b.mu.Unlock() -} - -// get returns the channel that receives a recvMsg in the buffer. -// -// Upon receipt of a recvMsg, the caller should call load to send another -// recvMsg onto the channel if there is any. -func (b *recvBuffer) get() <-chan recvMsg { - return b.c -} - -// -// recvBufferReader implements io.Reader interface to read the data from -// recvBuffer. -type recvBufferReader struct { - ctx context.Context - ctxDone <-chan struct{} // cache of ctx.Done() (for performance). - recv *recvBuffer - last []byte // Stores the remaining data in the previous calls. - err error -} - -// Read reads the next len(p) bytes from last. If last is drained, it tries to -// read additional data from recv. It blocks if there no additional data available -// in recv. If Read returns any non-nil error, it will continue to return that error. -func (r *recvBufferReader) Read(p []byte) (n int, err error) { - if r.err != nil { - return 0, r.err - } - n, r.err = r.read(p) - return n, r.err -} - -func (r *recvBufferReader) read(p []byte) (n int, err error) { - if r.last != nil && len(r.last) > 0 { - // Read remaining data left in last call. 
- copied := copy(p, r.last) - r.last = r.last[copied:] - return copied, nil - } - select { - case <-r.ctxDone: - return 0, ContextErr(r.ctx.Err()) - case m := <-r.recv.get(): - r.recv.load() - if m.err != nil { - return 0, m.err - } - copied := copy(p, m.data) - r.last = m.data[copied:] - return copied, nil - } -} - -type streamState uint32 - -const ( - streamActive streamState = iota - streamWriteDone // EndStream sent - streamReadDone // EndStream received - streamDone // the entire stream is finished. -) - -// Stream represents an RPC in the transport layer. -type Stream struct { - id uint32 - st ServerTransport // nil for client side Stream - ctx context.Context // the associated context of the stream - cancel context.CancelFunc // always nil for client side Stream - done chan struct{} // closed at the end of stream to unblock writers. On the client side. - ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance) - method string // the associated RPC method of the stream - recvCompress string - sendCompress string - buf *recvBuffer - trReader io.Reader - fc *inFlow - recvQuota uint32 - wq *writeQuota - - // Callback to state application's intentions to read data. This - // is used to adjust flow control, if needed. - requestRead func(int) - - headerChan chan struct{} // closed to indicate the end of header metadata. - headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. - header metadata.MD // the received header metadata. - trailer metadata.MD // the key-value map of trailer metadata. 
- - headerOk bool // becomes true from the first header is about to send - state streamState - - status *status.Status // the status error received from the server - - bytesReceived uint32 // indicates whether any bytes have been received on this stream - unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream - - // contentSubtype is the content-subtype for requests. - // this must be lowercase or the behavior is undefined. - contentSubtype string -} - -func (s *Stream) swapState(st streamState) streamState { - return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st))) -} - -func (s *Stream) compareAndSwapState(oldState, newState streamState) bool { - return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState)) -} - -func (s *Stream) getState() streamState { - return streamState(atomic.LoadUint32((*uint32)(&s.state))) -} - -func (s *Stream) waitOnHeader() error { - if s.headerChan == nil { - // On the server headerChan is always nil since a stream originates - // only after having received headers. - return nil - } - select { - case <-s.ctx.Done(): - return ContextErr(s.ctx.Err()) - case <-s.headerChan: - return nil - } -} - -// RecvCompress returns the compression algorithm applied to the inbound -// message. It is empty string if there is no compression applied. -func (s *Stream) RecvCompress() string { - if err := s.waitOnHeader(); err != nil { - return "" - } - return s.recvCompress -} - -// SetSendCompress sets the compression algorithm to the stream. -func (s *Stream) SetSendCompress(str string) { - s.sendCompress = str -} - -// Done returns a chanel which is closed when it receives the final status -// from the server. -func (s *Stream) Done() <-chan struct{} { - return s.done -} - -// Header acquires the key-value pairs of header metadata once it -// is available. 
It blocks until i) the metadata is ready or ii) there is no -// header metadata or iii) the stream is canceled/expired. -func (s *Stream) Header() (metadata.MD, error) { - err := s.waitOnHeader() - // Even if the stream is closed, header is returned if available. - select { - case <-s.headerChan: - if s.header == nil { - return nil, nil - } - return s.header.Copy(), nil - default: - } - return nil, err -} - -// Trailer returns the cached trailer metedata. Note that if it is not called -// after the entire stream is done, it could return an empty MD. Client -// side only. -// It can be safely read only after stream has ended that is either read -// or write have returned io.EOF. -func (s *Stream) Trailer() metadata.MD { - c := s.trailer.Copy() - return c -} - -// ServerTransport returns the underlying ServerTransport for the stream. -// The client side stream always returns nil. -func (s *Stream) ServerTransport() ServerTransport { - return s.st -} - -// ContentSubtype returns the content-subtype for a request. For example, a -// content-subtype of "proto" will result in a content-type of -// "application/grpc+proto". This will always be lowercase. See -// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for -// more details. -func (s *Stream) ContentSubtype() string { - return s.contentSubtype -} - -// Context returns the context of the stream. -func (s *Stream) Context() context.Context { - return s.ctx -} - -// Method returns the method for the stream. -func (s *Stream) Method() string { - return s.method -} - -// Status returns the status received from the server. -// Status can be read safely only after the stream has ended, -// that is, read or write has returned io.EOF. -func (s *Stream) Status() *status.Status { - return s.status -} - -// SetHeader sets the header metadata. This can be called multiple times. -// Server side only. -// This should not be called in parallel to other data writes. 
-func (s *Stream) SetHeader(md metadata.MD) error { - if md.Len() == 0 { - return nil - } - if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) { - return ErrIllegalHeaderWrite - } - s.header = metadata.Join(s.header, md) - return nil -} - -// SendHeader sends the given header metadata. The given metadata is -// combined with any metadata set by previous calls to SetHeader and -// then written to the transport stream. -func (s *Stream) SendHeader(md metadata.MD) error { - t := s.ServerTransport() - return t.WriteHeader(s, md) -} - -// SetTrailer sets the trailer metadata which will be sent with the RPC status -// by the server. This can be called multiple times. Server side only. -// This should not be called parallel to other data writes. -func (s *Stream) SetTrailer(md metadata.MD) error { - if md.Len() == 0 { - return nil - } - s.trailer = metadata.Join(s.trailer, md) - return nil -} - -func (s *Stream) write(m recvMsg) { - s.buf.put(m) -} - -// Read reads all p bytes from the wire for this stream. -func (s *Stream) Read(p []byte) (n int, err error) { - // Don't request a read if there was an error earlier - if er := s.trReader.(*transportReader).er; er != nil { - return 0, er - } - s.requestRead(len(p)) - return io.ReadFull(s.trReader, p) -} - -// tranportReader reads all the data available for this Stream from the transport and -// passes them into the decoder, which converts them into a gRPC message stream. -// The error is io.EOF when the stream is done or another non-nil error if -// the stream broke. -type transportReader struct { - reader io.Reader - // The handler to control the window update procedure for both this - // particular stream and the associated transport. 
- windowHandler func(int) - er error -} - -func (t *transportReader) Read(p []byte) (n int, err error) { - n, err = t.reader.Read(p) - if err != nil { - t.er = err - return - } - t.windowHandler(n) - return -} - -// BytesReceived indicates whether any bytes have been received on this stream. -func (s *Stream) BytesReceived() bool { - return atomic.LoadUint32(&s.bytesReceived) == 1 -} - -// Unprocessed indicates whether the server did not process this stream -- -// i.e. it sent a refused stream or GOAWAY including this stream ID. -func (s *Stream) Unprocessed() bool { - return atomic.LoadUint32(&s.unprocessed) == 1 -} - -// GoString is implemented by Stream so context.String() won't -// race when printing %#v. -func (s *Stream) GoString() string { - return fmt.Sprintf("", s, s.method) -} - // state of transport type transportState int @@ -476,7 +120,13 @@ type Options struct { // Delay is a hint to the transport implementation for whether // the data could be buffered for a batching write. The // transport implementation may ignore the hint. + // TODO(mmukhi, dfawley): Should this be deleted? Delay bool + + // IsCompressed indicates weather the message being written + // was compressed or not. Transport relays this information + // to the API that generates gRPC-specific message header. + IsCompressed bool } // CallHdr carries the information of a particular RPC. @@ -525,7 +175,7 @@ type ClientTransport interface { // Write sends the data for the given stream. A nil stream indicates // the write is to be performed on the transport as a whole. - Write(s *Stream, hdr []byte, data []byte, opts *Options) error + Write(s *Stream, data []byte, opts *Options) error // NewStream creates a Stream for an RPC. NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) @@ -573,7 +223,7 @@ type ServerTransport interface { // Write sends the data for the given stream. // Write may not be called on all streams. 
- Write(s *Stream, hdr []byte, data []byte, opts *Options) error + Write(s *Stream, data []byte, opts *Options) error // WriteStatus sends the status of a stream to the client. WriteStatus is // the final call made on a stream and always occurs. diff --git a/transport/transport_test.go b/transport/transport_test.go index e0201efa8..07a81fde2 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -100,8 +100,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { req = expectedRequestLarge resp = expectedResponseLarge } - p := make([]byte, len(req)) - _, err := s.Read(p) + _, p, err := s.Read(math.MaxInt32) if err != nil { return } @@ -109,31 +108,26 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { t.Fatalf("handleStream got %v, want %v", p, req) } // send a response back to the client. - h.t.Write(s, nil, resp, &Options{}) + h.t.Write(s, resp, &Options{}) // send the trailer to end the stream. h.t.WriteStatus(s, status.New(codes.OK, "")) } func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) { - header := make([]byte, 5) for { - if _, err := s.Read(header); err != nil { + _, msg, err := s.Read(math.MaxInt32) + if err != nil { if err == io.EOF { h.t.WriteStatus(s, status.New(codes.OK, "")) return } - t.Fatalf("Error on server while reading data header: %v", err) + t.Errorf("Error on server while reading data header: %v", err) + return } - sz := binary.BigEndian.Uint32(header[1:]) - msg := make([]byte, int(sz)) - if _, err := s.Read(msg); err != nil { - t.Fatalf("Error on server while reading message: %v", err) + if err := h.t.Write(s, msg, &Options{}); err != nil { + t.Errorf("Error on server while writing: %v", err) + return } - buf := make([]byte, sz+5) - buf[0] = byte(0) - binary.BigEndian.PutUint32(buf[1:], uint32(sz)) - copy(buf[5:], msg) - h.t.Write(s, nil, buf, &Options{}) } } @@ -189,12 +183,10 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { req 
= expectedRequestLarge resp = expectedResponseLarge } - p := make([]byte, len(req)) - // Wait before reading. Give time to client to start sending // before server starts reading. time.Sleep(2 * time.Second) - _, err := s.Read(p) + _, p, err := s.Read(math.MaxInt32) if err != nil { t.Errorf("s.Read(_) = _, %v, want _, ", err) return @@ -205,7 +197,7 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { return } // send a response back to the client. - if err := h.t.Write(s, nil, resp, &Options{}); err != nil { + if err := h.t.Write(s, resp, &Options{}); err != nil { t.Errorf("server Write got %v, want ", err) return } @@ -223,8 +215,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) { req = expectedRequestLarge resp = expectedResponseLarge } - p := make([]byte, len(req)) - _, err := s.Read(p) + _, p, err := s.Read(math.MaxInt32) if err != nil { t.Errorf("s.Read(_) = _, %v, want _, ", err) return @@ -237,7 +228,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) { // Wait before sending. Give time to client to start reading // before server starts sending. 
time.Sleep(2 * time.Second) - if err := h.t.Write(s, nil, resp, &Options{}); err != nil { + if err := h.t.Write(s, resp, &Options{}); err != nil { t.Errorf("server Write got %v, want ", err) return } @@ -442,7 +433,7 @@ func TestInflightStreamClosing(t *testing.T) { serr := StreamError{Desc: "client connection is closing"} go func() { defer close(donec) - if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr { + if _, _, err := stream.Read(math.MaxInt32); err != serr { t.Errorf("unexpected Stream error %v, expected %v", err, serr) } }() @@ -858,15 +849,14 @@ func TestClientSendAndReceive(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF { + if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF { t.Fatalf("failed to send data: %v", err) } - p := make([]byte, len(expectedResponse)) - _, recvErr := s1.Read(p) + _, p, recvErr := s1.Read(math.MaxInt32) if recvErr != nil || !bytes.Equal(p, expectedResponse) { t.Fatalf("Error: %v, want ; Result: %v, want %v", recvErr, p, expectedResponse) } - _, recvErr = s1.Read(p) + _, _, recvErr = s1.Read(math.MaxInt32) if recvErr != io.EOF { t.Fatalf("Error: %v; want ", recvErr) } @@ -895,16 +885,15 @@ func performOneRPC(ct ClientTransport) { Last: true, Delay: false, } - if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF { + if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF { time.Sleep(5 * time.Millisecond) // The following s.Recv()'s could error out because the // underlying transport is gone. 
// // Read response - p := make([]byte, len(expectedResponse)) - s.Read(p) + s.Read(math.MaxInt32) // Read io.EOF - s.Read(p) + s.Read(math.MaxInt32) } } @@ -939,14 +928,13 @@ func TestLargeMessage(t *testing.T) { if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) } - if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { + if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) } - p := make([]byte, len(expectedResponseLarge)) - if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { - t.Errorf("s.Read(%v) = _, %v, want %v, ", err, p, expectedResponse) + if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) { + t.Errorf("s.Read(math.MaxInt32) = %v, %v, want %v, ", p, err, expectedResponse) } - if _, err = s.Read(p); err != io.EOF { + if _, _, err = s.Read(math.MaxInt32); err != io.EOF { t.Errorf("Failed to complete the stream %v; want ", err) } }() @@ -974,19 +962,18 @@ func TestLargeMessageWithDelayRead(t *testing.T) { t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) return } - if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { + if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) return } - p := make([]byte, len(expectedResponseLarge)) // Give time to server to begin sending before client starts reading. 
time.Sleep(2 * time.Second) - if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { + if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) { t.Errorf("s.Read(_) = _, %v, want _, ", err) return } - if _, err = s.Read(p); err != io.EOF { + if _, _, err = s.Read(math.MaxInt32); err != io.EOF { t.Errorf("Failed to complete the stream %v; want ", err) } }() @@ -1017,16 +1004,15 @@ func TestLargeMessageDelayWrite(t *testing.T) { // Give time to server to start reading before client starts sending. time.Sleep(2 * time.Second) - if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { + if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) return } - p := make([]byte, len(expectedResponseLarge)) - if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { + if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) { t.Errorf("io.ReadFull(%v) = _, %v, want %v, ", err, p, expectedResponse) return } - if _, err = s.Read(p); err != io.EOF { + if _, _, err = s.Read(math.MaxInt32); err != io.EOF { t.Errorf("Failed to complete the stream %v; want ", err) } }() @@ -1047,19 +1033,10 @@ func TestGracefulClose(t *testing.T) { t.Fatalf("NewStream(_, _) = _, %v, want _, ", err) } msg := make([]byte, 1024) - outgoingHeader := make([]byte, 5) - outgoingHeader[0] = byte(0) - binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg))) - incomingHeader := make([]byte, 5) - if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil { + if err := ct.Write(s, msg, &Options{}); err != nil { t.Fatalf("Error while writing: %v", err) } - if _, err := s.Read(incomingHeader); err != nil { - t.Fatalf("Error while reading: %v", err) - } - sz := binary.BigEndian.Uint32(incomingHeader[1:]) - recvMsg := make([]byte, int(sz)) - if _, err := 
s.Read(recvMsg); err != nil { + if _, _, err := s.Read(math.MaxInt32); err != nil { t.Fatalf("Error while reading: %v", err) } if err = ct.GracefulClose(); err != nil { @@ -1075,14 +1052,14 @@ func TestGracefulClose(t *testing.T) { if err == errStreamDrain { return } - ct.Write(str, nil, nil, &Options{Last: true}) - if _, err := str.Read(make([]byte, 8)); err != errStreamDrain { + ct.Write(str, nil, &Options{Last: true}) + if _, _, err := str.Read(math.MaxInt32); err != errStreamDrain { t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, errStreamDrain) } }() } - ct.Write(s, nil, nil, &Options{Last: true}) - if _, err := s.Read(incomingHeader); err != io.EOF { + ct.Write(s, nil, &Options{Last: true}) + if _, _, err := s.Read(math.MaxInt32); err != io.EOF { t.Fatalf("Client expected EOF from the server. Got: %v", err) } // The stream which was created before graceful close can still proceed. @@ -1110,13 +1087,13 @@ func TestLargeMessageSuspension(t *testing.T) { }() // Write should not be done successfully due to flow control. msg := make([]byte, initialWindowSize*8) - ct.Write(s, nil, msg, &Options{}) - err = ct.Write(s, nil, msg, &Options{Last: true}) + ct.Write(s, msg, &Options{}) + err = ct.Write(s, msg, &Options{Last: true}) if err != errStreamDone { t.Fatalf("Write got %v, want io.EOF", err) } expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) - if _, err := s.Read(make([]byte, 8)); err != expectedErr { + if _, _, err := s.Read(math.MaxInt32); err != expectedErr { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) } ct.Close() @@ -1305,7 +1282,7 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) { t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id) } // Exhaust client's connection window. 
- if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { + if err := st.Write(sstream1, make([]byte, defaultWindowSize), &Options{}); err != nil { t.Fatalf("Server failed to write data. Err: %v", err) } notifyChan = make(chan struct{}) @@ -1330,17 +1307,17 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) { t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id) } // Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream. - if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { + if err := st.Write(sstream2, make([]byte, defaultWindowSize), &Options{}); err != nil { t.Fatalf("Server failed to write data. Err: %v", err) } // Client should be able to read data on second stream. - if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil { + if _, _, err := cstream2.Read(math.MaxInt32); err != nil { t.Fatalf("_.Read(_) = _, %v, want _, ", err) } // Client should be able to read data on first stream. - if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil { + if _, _, err := cstream1.Read(math.MaxInt32); err != nil { t.Fatalf("_.Read(_) = _, %v, want _, ", err) } } @@ -1373,7 +1350,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { t.Fatalf("Failed to create 1st stream. Err: %v", err) } // Exhaust server's connection window. - if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil { + if err := client.Write(cstream1, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil { t.Fatalf("Client failed to write data. Err: %v", err) } //Client should be able to create another stream and send data on it. @@ -1381,7 +1358,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { if err != nil { t.Fatalf("Failed to create 2nd stream. 
Err: %v", err) } - if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil { + if err := client.Write(cstream2, make([]byte, defaultWindowSize), &Options{}); err != nil { t.Fatalf("Client failed to write data. Err: %v", err) } // Get the streams on server. @@ -1403,11 +1380,11 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) { } st.mu.Unlock() // Reading from the stream on server should succeed. - if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil { + if _, _, err := sstream1.Read(math.MaxInt32); err != nil { t.Fatalf("_.Read(_) = %v, want ", err) } - if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF { + if _, _, err := sstream1.Read(math.MaxInt32); err != io.EOF { t.Fatalf("_.Read(_) = %v, want io.EOF", err) } @@ -1616,11 +1593,10 @@ func TestEncodingRequiredStatus(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone { + if err := ct.Write(s, expectedRequest, &opts); err != nil && err != errStreamDone { t.Fatalf("Failed to write the request: %v", err) } - p := make([]byte, http2MaxFrameLen) - if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF { + if _, _, err := s.Read(math.MaxInt32); err != io.EOF { t.Fatalf("Read got error %v, want %v", err, io.EOF) } if !reflect.DeepEqual(s.Status(), encodingTestStatus) { @@ -1640,8 +1616,7 @@ func TestInvalidHeaderField(t *testing.T) { if err != nil { return } - p := make([]byte, http2MaxFrameLen) - _, err = s.trReader.(*transportReader).Read(p) + _, _, err = s.Read(math.MaxInt32) if se, ok := err.(StreamError); !ok || se.Code != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField) } @@ -1764,26 +1739,17 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) t.Fatalf("Failed to 
create stream. Err: %v", err) } msg := make([]byte, msgSize) - buf := make([]byte, msgSize+5) - buf[0] = byte(0) - binary.BigEndian.PutUint32(buf[1:], uint32(msgSize)) - copy(buf[5:], msg) opts := Options{} - header := make([]byte, 5) for i := 1; i <= 10; i++ { - if err := ct.Write(cstream, nil, buf, &opts); err != nil { + if err := ct.Write(cstream, msg, &opts); err != nil { t.Fatalf("Error on client while writing message: %v", err) } - if _, err := cstream.Read(header); err != nil { - t.Fatalf("Error on client while reading data frame header: %v", err) - } - sz := binary.BigEndian.Uint32(header[1:]) - recvMsg := make([]byte, int(sz)) - if _, err := cstream.Read(recvMsg); err != nil { + _, recvMsg, err := cstream.Read(math.MaxInt32) + if err != nil { t.Fatalf("Error on client while reading data: %v", err) } - if len(recvMsg) != len(msg) { - t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg)) + if !bytes.Equal(recvMsg, msg) { + t.Fatalf("Message received by client(len: %d) not equal to what was expected(len: %d)", len(recvMsg), len(msg)) } } var sstream *Stream @@ -1794,8 +1760,8 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) st.mu.Unlock() loopyServerStream := st.loopy.estdStreams[sstream.id] loopyClientStream := ct.loopy.estdStreams[cstream.id] - ct.Write(cstream, nil, nil, &Options{Last: true}) // Close the stream. - if _, err := cstream.Read(header); err != io.EOF { + ct.Write(cstream, nil, &Options{Last: true}) // Close the stream. + if _, _, err := cstream.Read(math.MaxInt32); err != io.EOF { t.Fatalf("Client expected an EOF from the server. Got: %v", err) } // Sleep for a little to make sure both sides flush out their buffers. 
@@ -1816,11 +1782,11 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, ct.loopy.sendQuota) } // Check stream flow control. - if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { - t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding) + if int(cstream.fc.limit)-int(cstream.fc.rcvd) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { + t.Fatalf("Account mismatch: client stream inflow limit(%d) - rcvd(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.rcvd, st.loopy.oiws, loopyServerStream.bytesOutStanding) } - if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding { - t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, ct.loopy.oiws, loopyClientStream.bytesOutStanding) + if int(sstream.fc.limit)-int(sstream.fc.rcvd) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding { + t.Fatalf("Account mismatch: server stream inflow limit(%d) - rcvd(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.rcvd, ct.loopy.oiws, loopyClientStream.bytesOutStanding) } } @@ -2000,8 +1966,7 @@ func 
testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh) defer cleanUp() want := httpStatusConvTab[httpStatus] - buf := make([]byte, 8) - _, err := stream.Read(buf) + _, _, err := stream.Read(math.MaxInt32) if err == nil { t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want) } @@ -2017,8 +1982,7 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) { stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader) defer cleanUp() - buf := make([]byte, 8) - _, err := stream.Read(buf) + _, _, err := stream.Read(math.MaxInt32) if err != io.EOF { t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err) } @@ -2035,45 +1999,25 @@ func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) { // If any error occurs on a call to Stream.Read, future calls // should continue to return that same error. 
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { - testRecvBuffer := newRecvBuffer() - s := &Stream{ - ctx: context.Background(), - buf: testRecvBuffer, - requestRead: func(int) {}, - } - s.trReader = &transportReader{ - reader: &recvBufferReader{ - ctx: s.ctx, - ctxDone: s.ctx.Done(), - recv: s.buf, - }, - windowHandler: func(int) {}, - } - testData := make([]byte, 1) - testData[0] = 5 + s := newStream(context.Background()) testErr := errors.New("test error") - s.write(recvMsg{data: testData, err: testErr}) + s.notifyErr(testErr) - inBuf := make([]byte, 1) - actualCount, actualErr := s.Read(inBuf) - if actualCount != 0 { - t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount) - } - if actualErr.Error() != testErr.Error() { - t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) + pf, inBuf, actualErr := s.Read(math.MaxInt32) + if pf != false || inBuf != nil || actualErr.Error() != testErr.Error() { + t.Errorf("%v, %v, %v := s.Read(_) differs; want false, , %v", pf, inBuf, actualErr, testErr) } - s.write(recvMsg{data: testData, err: nil}) - s.write(recvMsg{data: testData, err: errors.New("different error from first")}) + testData := make([]byte, 6) + testData[0] = byte(1) + binary.BigEndian.PutUint32(testData[1:], uint32(1)) + s.consume(testData, 0) + s.notifyErr(errors.New("different error from first")) for i := 0; i < 2; i++ { - inBuf := make([]byte, 1) - actualCount, actualErr := s.Read(inBuf) - if actualCount != 0 { - t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount) - } - if actualErr.Error() != testErr.Error() { - t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) + pf, inBuf, actualErr := s.Read(math.MaxInt32) + if pf != false || inBuf != nil || actualErr.Error() != testErr.Error() { + t.Errorf("%v, %v, %v := s.Read(_) differs; want false, , %v", pf, 
inBuf, actualErr, testErr) } } } @@ -2113,11 +2057,7 @@ func runPingPongTest(t *testing.T, msgSize int) { t.Fatalf("Failed to create stream. Err: %v", err) } msg := make([]byte, msgSize) - outgoingHeader := make([]byte, 5) - outgoingHeader[0] = byte(0) - binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize)) opts := &Options{} - incomingHeader := make([]byte, 5) done := make(chan struct{}) go func() { timer := time.NewTimer(time.Second * 5) @@ -2127,23 +2067,22 @@ func runPingPongTest(t *testing.T, msgSize int) { for { select { case <-done: - ct.Write(stream, nil, nil, &Options{Last: true}) - if _, err := stream.Read(incomingHeader); err != io.EOF { + ct.Write(stream, nil, &Options{Last: true}) + if _, _, err := stream.Read(math.MaxInt32); err != io.EOF { t.Fatalf("Client expected EOF from the server. Got: %v", err) } return default: - if err := ct.Write(stream, outgoingHeader, msg, opts); err != nil { + if err := ct.Write(stream, msg, opts); err != nil { t.Fatalf("Error on client while writing message. Err: %v", err) } - if _, err := stream.Read(incomingHeader); err != nil { - t.Fatalf("Error on client while reading data header. Err: %v", err) - } - sz := binary.BigEndian.Uint32(incomingHeader[1:]) - recvMsg := make([]byte, int(sz)) - if _, err := stream.Read(recvMsg); err != nil { + _, recvMsg, err := stream.Read(math.MaxInt32) + if err != nil { t.Fatalf("Error on client while reading data. Err: %v", err) } + if !bytes.Equal(recvMsg, msg) { + t.Fatalf("%v != %v", recvMsg, msg) + } } } }