transport: refactor to split ClientStream from ServerStream from common Stream functionality (#7802)

This commit is contained in:
Doug Fawley 2024-11-04 13:42:38 -08:00 committed by GitHub
parent 70e8931a0e
commit 2a18bfcb16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 410 additions and 340 deletions

View File

@ -0,0 +1,115 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"sync/atomic"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// ClientStream implements streaming functionality for a gRPC client.
type ClientStream struct {
*Stream // Embed for common stream functionality.
ct ClientTransport
done chan struct{} // closed at the end of stream to unblock writers.
doneFunc func() // invoked at the end of stream.
headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value).
headerValid bool
header metadata.MD // the received header metadata
noHeaders bool // set if the client never received headers (set only after the stream is done).
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
status *status.Status // the status error received from the server
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *ClientStream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *ClientStream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}
func (s *ClientStream) waitOnHeader() {
select {
case <-s.ctx.Done():
// Close the stream to prevent headers/trailers from changing after
// this function returns.
s.ct.CloseStream(s, ContextErr(s.ctx.Err()))
// headerChan could possibly not be closed yet if closeStream raced
// with operateHeaders; wait until it is closed explicitly here.
<-s.headerChan
case <-s.headerChan:
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *ClientStream) RecvCompress() string {
s.waitOnHeader()
return s.recvCompress
}
// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *ClientStream) Done() <-chan struct{} {
return s.done
}
// Header returns the header metadata of the stream. Acquires the key-value
// pairs of header metadata once it is available. It blocks until i) the
// metadata is ready or ii) there is no header metadata or iii) the stream is
// canceled/expired.
func (s *ClientStream) Header() (metadata.MD, error) {
s.waitOnHeader()
if !s.headerValid || s.noHeaders {
return nil, s.status.Err()
}
return s.header.Copy(), nil
}
// TrailersOnly blocks until a header or trailers-only frame is received and
// then returns true if the stream was trailers-only. If the stream ends
// before headers are received, returns true, nil.
func (s *ClientStream) TrailersOnly() bool {
s.waitOnHeader()
return s.noHeaders
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, after Done() is closed.
func (s *ClientStream) Status() *status.Status {
return s.status
}

View File

@ -225,7 +225,7 @@ func (ht *serverHandlerTransport) do(fn func()) error {
}
}
func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
func (ht *serverHandlerTransport) WriteStatus(s *ServerStream, st *status.Status) error {
ht.writeStatusMu.Lock()
defer ht.writeStatusMu.Unlock()
@ -289,14 +289,14 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
// writePendingHeaders sets common and custom headers on the first
// write call (Write, WriteHeader, or WriteStatus)
func (ht *serverHandlerTransport) writePendingHeaders(s *Stream) {
func (ht *serverHandlerTransport) writePendingHeaders(s *ServerStream) {
ht.writeCommonHeaders(s)
ht.writeCustomHeaders(s)
}
// writeCommonHeaders sets common headers on the first write
// call (Write, WriteHeader, or WriteStatus).
func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
func (ht *serverHandlerTransport) writeCommonHeaders(s *ServerStream) {
h := ht.rw.Header()
h["Date"] = nil // suppress Date to make tests happy; TODO: restore
h.Set("Content-Type", ht.contentType)
@ -317,7 +317,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
// writeCustomHeaders sets custom headers set on the stream via SetHeader
// on the first write call (Write, WriteHeader, or WriteStatus)
func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
func (ht *serverHandlerTransport) writeCustomHeaders(s *ServerStream) {
h := ht.rw.Header()
s.hdrMu.Lock()
@ -333,7 +333,7 @@ func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
s.hdrMu.Unlock()
}
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSlice, _ *Options) error {
func (ht *serverHandlerTransport) Write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *Options) error {
// Always take a reference because otherwise there is no guarantee the data will
// be available after this function returns. This is what callers to Write
// expect.
@ -357,7 +357,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data mem.BufferSl
return nil
}
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
func (ht *serverHandlerTransport) WriteHeader(s *ServerStream, md metadata.MD) error {
if err := s.SetHeader(md); err != nil {
return err
}
@ -385,7 +385,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err
}
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.
var cancel context.CancelFunc
if ht.timeoutSet {
@ -408,16 +408,18 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
req := ht.req
s := &Stream{
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
s := &ServerStream{
Stream: &Stream{
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
buf: newRecvBuffer(),
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
}
s.trReader = &transportReader{

View File

@ -274,7 +274,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st := newHandleStreamTest(t)
handleStream := func(s *Stream) {
handleStream := func(s *ServerStream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
@ -313,7 +313,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
st.ht.HandleStreams(
context.Background(), func(s *Stream) { go handleStream(s) },
context.Background(), func(s *ServerStream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
@ -342,11 +342,11 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
st := newHandleStreamTest(t)
handleStream := func(s *Stream) {
handleStream := func(s *ServerStream) {
st.ht.WriteStatus(s, status.New(statusCode, msg))
}
st.ht.HandleStreams(
context.Background(), func(s *Stream) { go handleStream(s) },
context.Background(), func(s *ServerStream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
@ -379,7 +379,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
if err != nil {
t.Fatal(err)
}
runStream := func(s *Stream) {
runStream := func(s *ServerStream) {
defer bodyw.Close()
select {
case <-s.ctx.Done():
@ -395,7 +395,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
}
ht.HandleStreams(
context.Background(), func(s *Stream) { go runStream(s) },
context.Background(), func(s *ServerStream) { go runStream(s) },
)
wantHeader := http.Header{
"Date": nil,
@ -412,7 +412,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
// concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
@ -433,7 +433,7 @@ func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
// TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
// following "WriteStatus" does not panic writing to closed "writes" channel.
func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
@ -444,10 +444,10 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
})
}
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
st := newHandleStreamTest(t)
st.ht.HandleStreams(
context.Background(), func(s *Stream) { go handleStream(st, s) },
context.Background(), func(s *ServerStream) { go handleStream(st, s) },
)
}
@ -476,11 +476,11 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
}
hst := newHandleStreamTest(t)
handleStream := func(s *Stream) {
handleStream := func(s *ServerStream) {
hst.ht.WriteStatus(s, st)
}
hst.ht.HandleStreams(
context.Background(), func(s *Stream) { go handleStream(s) },
context.Background(), func(s *ServerStream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,

View File

@ -123,7 +123,7 @@ type http2Client struct {
mu sync.Mutex // guard the following variables
nextID uint32
state transportState
activeStreams map[uint32]*Stream
activeStreams map[uint32]*ClientStream
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
// goAwayReason records the http2.ErrCode and debug data received with the
@ -339,7 +339,7 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize),
fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme,
activeStreams: make(map[uint32]*Stream),
activeStreams: make(map[uint32]*ClientStream),
isSecure: isSecure,
perRPCCreds: perRPCCreds,
kp: kp,
@ -480,17 +480,19 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
return t, nil
}
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientStream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
ct: t,
done: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
headerChan: make(chan struct{}),
contentSubtype: callHdr.ContentSubtype,
doneFunc: callHdr.DoneFunc,
s := &ClientStream{
Stream: &Stream{
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
contentSubtype: callHdr.ContentSubtype,
},
ct: t,
done: make(chan struct{}),
headerChan: make(chan struct{}),
doneFunc: callHdr.DoneFunc,
}
s.wq = newWriteQuota(defaultWriteQuota, s.done)
s.requestRead = func(n int) {
@ -738,7 +740,7 @@ func (e NewStreamError) Error() string {
// NewStream creates a stream and registers it into the transport as "active"
// streams. All non-nil errors returned will be *NewStreamError.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) {
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) {
ctx = peer.NewContext(ctx, t.getPeer())
// ServerName field of the resolver returned address takes precedence over
@ -910,7 +912,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream,
// CloseStream clears the footprint of a stream when the stream is not needed any more.
// This must not be executed in reader's goroutine.
func (t *http2Client) CloseStream(s *Stream, err error) {
func (t *http2Client) CloseStream(s *ClientStream, err error) {
var (
rst bool
rstCode http2.ErrCode
@ -922,7 +924,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false)
}
func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) {
func (t *http2Client) closeStream(s *ClientStream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) {
// Set stream status to done.
if s.swapState(streamDone) == streamDone {
// If it was already done, return. If multiple closeStream calls
@ -1085,7 +1087,7 @@ func (t *http2Client) GracefulClose() {
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil.
func (t *http2Client) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error {
func (t *http2Client) Write(s *ClientStream, hdr []byte, data mem.BufferSlice, opts *Options) error {
reader := data.Reader()
if opts.Last {
@ -1117,7 +1119,7 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *O
return nil
}
func (t *http2Client) getStream(f http2.Frame) *Stream {
func (t *http2Client) getStream(f http2.Frame) *ClientStream {
t.mu.Lock()
s := t.activeStreams[f.Header().StreamID]
t.mu.Unlock()
@ -1127,7 +1129,7 @@ func (t *http2Client) getStream(f http2.Frame) *Stream {
// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Client) adjustWindow(s *Stream, n uint32) {
func (t *http2Client) adjustWindow(s *ClientStream, n uint32) {
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
@ -1136,7 +1138,7 @@ func (t *http2Client) adjustWindow(s *Stream, n uint32) {
// updateWindow adjusts the inbound quota for the stream.
// Window updates will be sent out when the cumulative quota
// exceeds the corresponding threshold.
func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) updateWindow(s *ClientStream, n uint32) {
if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
@ -1383,7 +1385,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) error {
return connectionErrorf(true, nil, "received goaway and there are no active streams")
}
streamsToClose := make([]*Stream, 0)
streamsToClose := make([]*ClientStream, 0)
for streamID, stream := range t.activeStreams {
if streamID > id && streamID <= upperLimit {
// The stream was unprocessed by the server.

View File

@ -111,7 +111,7 @@ type http2Server struct {
// already initialized since draining is already underway.
drainEvent *grpcsync.Event
state transportState
activeStreams map[uint32]*Stream
activeStreams map[uint32]*ServerStream
// idle is the time instant when the connection went idle.
// This is either the beginning of the connection or when the number of
// RPCs go down to 0.
@ -256,7 +256,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
state: reachable,
activeStreams: make(map[uint32]*Stream),
activeStreams: make(map[uint32]*ServerStream),
stats: config.StatsHandlers,
kp: kp,
idle: time.Now(),
@ -359,7 +359,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
// operateHeaders takes action on the decoded headers. Returns an error if fatal
// error encountered and transport needs to close, otherwise returns nil.
func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*ServerStream)) error {
// Acquire max stream ID lock for entire duration
t.maxStreamMu.Lock()
defer t.maxStreamMu.Unlock()
@ -385,11 +385,13 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
t.maxStreamID = streamID
buf := newRecvBuffer()
s := &Stream{
id: streamID,
s := &ServerStream{
Stream: &Stream{
id: streamID,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
},
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
headerWireLength: int(frame.Header().Length),
}
var (
@ -634,7 +636,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*ServerStream)) {
defer func() {
close(t.readerDone)
<-t.loopyWriterDone
@ -698,7 +700,7 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
}
}
func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
func (t *http2Server) getStream(f http2.Frame) (*ServerStream, bool) {
t.mu.Lock()
defer t.mu.Unlock()
if t.activeStreams == nil {
@ -716,7 +718,7 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Server) adjustWindow(s *Stream, n uint32) {
func (t *http2Server) adjustWindow(s *ServerStream, n uint32) {
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
@ -726,7 +728,7 @@ func (t *http2Server) adjustWindow(s *Stream, n uint32) {
// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
func (t *http2Server) updateWindow(s *Stream, n uint32) {
func (t *http2Server) updateWindow(s *ServerStream, n uint32) {
if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id,
increment: w,
@ -963,7 +965,7 @@ func (t *http2Server) checkForHeaderListSize(it any) bool {
return true
}
func (t *http2Server) streamContextErr(s *Stream) error {
func (t *http2Server) streamContextErr(s *ServerStream) error {
select {
case <-t.done:
return ErrConnClosing
@ -973,7 +975,7 @@ func (t *http2Server) streamContextErr(s *Stream) error {
}
// WriteHeader sends the header metadata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
func (t *http2Server) WriteHeader(s *ServerStream, md metadata.MD) error {
s.hdrMu.Lock()
defer s.hdrMu.Unlock()
if s.getState() == streamDone {
@ -1006,7 +1008,7 @@ func (t *http2Server) setResetPingStrikes() {
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
func (t *http2Server) writeHeaderLocked(s *Stream) error {
func (t *http2Server) writeHeaderLocked(s *ServerStream) error {
// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
// first and create a slice of that exact size.
headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
@ -1046,7 +1048,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
// There is no further I/O operations being able to perform on this stream.
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
func (t *http2Server) WriteStatus(s *ServerStream, st *status.Status) error {
s.hdrMu.Lock()
defer s.hdrMu.Unlock()
@ -1117,7 +1119,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data mem.BufferSlice, _ *Options) error {
func (t *http2Server) Write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *Options) error {
reader := data.Reader()
if !s.isHeaderSent() { // Headers haven't been written yet.
@ -1276,7 +1278,7 @@ func (t *http2Server) Close(err error) {
}
// deleteStream deletes the stream s from transport's active streams.
func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) {
t.mu.Lock()
if _, ok := t.activeStreams[s.id]; ok {
@ -1297,7 +1299,7 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
}
// finishStream closes the stream and puts the trailing headerFrame into controlbuf.
func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
func (t *http2Server) finishStream(s *ServerStream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.
@ -1321,7 +1323,7 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h
}
// closeStream clears the footprint of a stream when the stream is not needed any more.
func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) {
func (t *http2Server) closeStream(s *ServerStream, rst bool, rstCode http2.ErrCode, eosReceived bool) {
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines.

View File

@ -0,0 +1,158 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"context"
"errors"
"strings"
"sync"
"sync/atomic"
"google.golang.org/grpc/metadata"
)
// ServerStream implements streaming functionality for a gRPC server.
type ServerStream struct {
*Stream // Embed for common stream functionality.
st ServerTransport
ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance)
cancel context.CancelFunc // invoked at the end of stream to cancel ctx.
// Holds compressor names passed in grpc-accept-encoding metadata from the
// client.
clientAdvertisedCompressors string
headerWireLength int
// hdrMu protects outgoing header and trailer metadata.
hdrMu sync.Mutex
header metadata.MD // the outgoing header metadata. Updated by WriteHeader.
headerSent uint32 // atomically set to 1 when the headers are sent out.
}
// isHeaderSent indicates whether headers have been sent.
func (s *ServerStream) isHeaderSent() bool {
return atomic.LoadUint32(&s.headerSent) == 1
}
// updateHeaderSent updates headerSent and returns true
// if it was already set.
func (s *ServerStream) updateHeaderSent() bool {
return atomic.SwapUint32(&s.headerSent, 1) == 1
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *ServerStream) RecvCompress() string {
return s.recvCompress
}
// SendCompress returns the send compressor name.
func (s *ServerStream) SendCompress() string {
return s.sendCompress
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *ServerStream) ContentSubtype() string {
return s.contentSubtype
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *ServerStream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}
s.sendCompress = name
return nil
}
// SetContext sets the context of the stream. This will be deleted once the
// stats handler callouts all move to gRPC layer.
func (s *ServerStream) SetContext(ctx context.Context) {
s.ctx = ctx
}
// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *ServerStream) ClientAdvertisedCompressors() []string {
values := strings.Split(s.clientAdvertisedCompressors, ",")
for i, v := range values {
values[i] = strings.TrimSpace(v)
}
return values
}
// Header returns the header metadata of the stream. It returns the out header
// after t.WriteHeader is called. It does not block and must not be called
// until after WriteHeader.
func (s *ServerStream) Header() (metadata.MD, error) {
// Return the header in stream. It will be the out
// header after t.WriteHeader is called.
return s.header.Copy(), nil
}
// HeaderWireLength returns the size of the headers of the stream as received
// from the wire.
func (s *ServerStream) HeaderWireLength() int {
return s.headerWireLength
}
// SetHeader sets the header metadata. This can be called multiple times.
// This should not be called in parallel to other data writes.
func (s *ServerStream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.isHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.header = metadata.Join(s.header, md)
s.hdrMu.Unlock()
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *ServerStream) SendHeader(md metadata.MD) error {
return s.st.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times.
// This should not be called parallel to other data writes.
func (s *ServerStream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.hdrMu.Unlock()
return nil
}

View File

@ -27,7 +27,6 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"time"
@ -287,14 +286,8 @@ const (
// Stream represents an RPC in the transport layer.
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ct ClientTransport // nil for server side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
doneFunc func() // invoked at the end of stream on client side.
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
method string // the associated RPC method of the stream
ctx context.Context // the associated context of the stream
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
buf *recvBuffer
@ -302,58 +295,17 @@ type Stream struct {
fc *inFlow
wq *writeQuota
// Holds compressor names passed in grpc-accept-encoding metadata from the
// client. This is empty for the client side stream.
clientAdvertisedCompressors string
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
headerChan chan struct{} // closed to indicate the end of header metadata.
headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value). Not valid on server side.
headerValid bool
headerWireLength int // Only set on server side.
// hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex
// On client side, header keeps the received header metadata.
//
// On server side, header keeps the header set by SetHeader(). The complete
// header will merged into this after t.WriteHeader() is called.
header metadata.MD
trailer metadata.MD // the key-value map of trailer metadata.
noHeaders bool // set if the client never received headers (set only after the stream is done).
// On the server-side, headerSent is atomically set to 1 when the headers are sent out.
headerSent uint32
state streamState
// On client-side it is the status error received from the server.
// On server-side it is unused.
status *status.Status
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
}
// isHeaderSent is only valid on the server-side.
func (s *Stream) isHeaderSent() bool {
return atomic.LoadUint32(&s.headerSent) == 1
}
// updateHeaderSent updates headerSent and returns true
// if it was already set. It is valid only on server-side.
func (s *Stream) updateHeaderSent() bool {
return atomic.SwapUint32(&s.headerSent, 1) == 1
trailer metadata.MD // the key-value map of trailer metadata.
}
func (s *Stream) swapState(st streamState) streamState {
@ -368,110 +320,12 @@ func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
func (s *Stream) waitOnHeader() {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return
}
select {
case <-s.ctx.Done():
// Close the stream to prevent headers/trailers from changing after
// this function returns.
s.ct.CloseStream(s, ContextErr(s.ctx.Err()))
// headerChan could possibly not be closed yet if closeStream raced
// with operateHeaders; wait until it is closed explicitly here.
<-s.headerChan
case <-s.headerChan:
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *Stream) RecvCompress() string {
s.waitOnHeader()
return s.recvCompress
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}
s.sendCompress = name
return nil
}
// SendCompress returns the send compressor name.
func (s *Stream) SendCompress() string {
return s.sendCompress
}
// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() []string {
values := strings.Split(s.clientAdvertisedCompressors, ",")
for i, v := range values {
values[i] = strings.TrimSpace(v)
}
return values
}
// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// Header returns the header metadata of the stream.
//
// On client side, it acquires the key-value pairs of header metadata once it is
// available. It blocks until i) the metadata is ready or ii) there is no header
// metadata or iii) the stream is canceled/expired.
//
// On server side, it returns the out header after t.WriteHeader is called. It
// does not block and must not be called until after WriteHeader.
func (s *Stream) Header() (metadata.MD, error) {
if s.headerChan == nil {
// On server side, return the header in stream. It will be the out
// header after t.WriteHeader is called.
return s.header.Copy(), nil
}
s.waitOnHeader()
if !s.headerValid || s.noHeaders {
return nil, s.status.Err()
}
return s.header.Copy(), nil
}
// TrailersOnly blocks until a header or trailers-only frame is received and
// then returns true if the stream was trailers-only. If the stream ends
// before headers are received, returns true, nil. Client-side only.
func (s *Stream) TrailersOnly() bool {
s.waitOnHeader()
return s.noHeaders
}
// Trailer returns the cached trailer metadata. Note that if it is not called
// after the entire stream is done, it could return an empty MD. Client
// side only.
// after the entire stream is done, it could return an empty MD.
// It can be safely read only after stream has ended that is either read
// or write have returned io.EOF.
func (s *Stream) Trailer() metadata.MD {
c := s.trailer.Copy()
return c
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
return s.trailer.Copy()
}
// Context returns the context of the stream.
@ -479,69 +333,11 @@ func (s *Stream) Context() context.Context {
return s.ctx
}
// SetContext sets the context of the stream. This will be deleted once the
// stats handler callouts all move to gRPC layer.
func (s *Stream) SetContext(ctx context.Context) {
s.ctx = ctx
}
// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, after Done() is closed.
func (s *Stream) Status() *status.Status {
return s.status
}
// HeaderWireLength returns the size of the headers of the stream as received
// from the wire. Valid only on the server.
func (s *Stream) HeaderWireLength() int {
return s.headerWireLength
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.isHeaderSent() || s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.header = metadata.Join(s.header, md)
s.hdrMu.Unlock()
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *Stream) SendHeader(md metadata.MD) error {
return s.st.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
// This should not be called parallel to other data writes.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.getState() == streamDone {
return ErrIllegalHeaderWrite
}
s.hdrMu.Lock()
s.trailer = metadata.Join(s.trailer, md)
s.hdrMu.Unlock()
return nil
}
func (s *Stream) write(m recvMsg) {
s.buf.put(m)
}
@ -638,17 +434,6 @@ func (t *transportReader) Read(n int) (mem.Buffer, error) {
return buf, nil
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *Stream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}
// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {
@ -777,16 +562,16 @@ type ClientTransport interface {
// Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole.
Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error
Write(s *ClientStream, hdr []byte, data mem.BufferSlice, opts *Options) error
// NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error)
// CloseStream clears the footprint of a stream when the stream is
// not needed any more. The err indicates the error incurred when
// CloseStream is called. Must be called when a stream is finished
// unless the associated transport is closing.
CloseStream(stream *Stream, err error)
CloseStream(stream *ClientStream, err error)
// Error returns a channel that is closed when some I/O error
// happens. Typically the caller should have a goroutine to monitor
@ -821,19 +606,19 @@ type ClientTransport interface {
// Write methods for a given Stream will be called serially.
type ServerTransport interface {
// HandleStreams receives incoming streams using the given handler.
HandleStreams(context.Context, func(*Stream))
HandleStreams(context.Context, func(*ServerStream))
// WriteHeader sends the header metadata for the given stream.
// WriteHeader may not be called on all streams.
WriteHeader(s *Stream, md metadata.MD) error
WriteHeader(s *ServerStream, md metadata.MD) error
// Write sends the data for the given stream.
// Write may not be called on all streams.
Write(s *Stream, hdr []byte, data mem.BufferSlice, opts *Options) error
Write(s *ServerStream, hdr []byte, data mem.BufferSlice, opts *Options) error
// WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs.
WriteStatus(s *Stream, st *status.Status) error
WriteStatus(s *ServerStream, st *status.Status) error
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their

View File

@ -117,7 +117,7 @@ const (
pingpong
)
func (h *testStreamHandler) handleStreamAndNotify(*Stream) {
func (h *testStreamHandler) handleStreamAndNotify(*ServerStream) {
if h.notify == nil {
return
}
@ -130,7 +130,7 @@ func (h *testStreamHandler) handleStreamAndNotify(*Stream) {
}()
}
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
func (h *testStreamHandler) handleStream(t *testing.T, s *ServerStream) {
req := expectedRequest
resp := expectedResponse
if s.Method() == "foo.Large" {
@ -153,7 +153,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *ServerStream) {
header := make([]byte, 5)
for {
if _, err := s.readTo(header); err != nil {
@ -180,7 +180,7 @@ func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
}
}
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *ServerStream) {
conn, ok := s.st.(*http2Server)
if !ok {
t.Errorf("Failed to convert %v to *http2Server", s.st)
@ -213,14 +213,14 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
}
}
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *Stream) {
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(s *ServerStream) {
// raw newline is not accepted by http2 framer so it must be encoded.
h.t.WriteStatus(s, encodingTestStatus)
// Drain any remaining buffers from the stream since it was closed early.
s.Read(math.MaxInt)
}
func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) {
func (h *testStreamHandler) handleStreamInvalidHeaderField(s *ServerStream) {
headerFields := []hpack.HeaderField{}
headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
h.t.controlBuf.put(&headerFrame{
@ -234,7 +234,7 @@ func (h *testStreamHandler) handleStreamInvalidHeaderField(s *Stream) {
// stream-level flow control.
// This handler assumes dynamic flow control is turned off and assumes window
// sizes to be set to defaultWindowSize.
func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *ServerStream) {
req := expectedRequest
resp := expectedResponse
if s.Method() == "foo.Large" {
@ -385,17 +385,17 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
case notifyCall:
go transport.HandleStreams(context.Background(), h.handleStreamAndNotify)
case suspended:
go transport.HandleStreams(context.Background(), func(*Stream) {})
go transport.HandleStreams(context.Background(), func(*ServerStream) {})
case misbehaved:
go transport.HandleStreams(context.Background(), func(s *Stream) {
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
go h.handleStreamMisbehave(t, s)
})
case encodingRequiredStatus:
go transport.HandleStreams(context.Background(), func(s *Stream) {
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
go h.handleStreamEncodingRequiredStatus(s)
})
case invalidHeaderField:
go transport.HandleStreams(context.Background(), func(s *Stream) {
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
go h.handleStreamInvalidHeaderField(s)
})
case delayRead:
@ -404,15 +404,15 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.mu.Lock()
close(s.ready)
s.mu.Unlock()
go transport.HandleStreams(context.Background(), func(s *Stream) {
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
go h.handleStreamDelayRead(t, s)
})
case pingpong:
go transport.HandleStreams(context.Background(), func(s *Stream) {
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
go h.handleStreamPingPong(t, s)
})
default:
go transport.HandleStreams(context.Background(), func(s *Stream) {
go transport.HandleStreams(context.Background(), func(s *ServerStream) {
go h.handleStream(t, s)
})
}
@ -941,7 +941,7 @@ func (s) TestMaxStreams(t *testing.T) {
}
// Keep creating streams until one fails with deadline exceeded, marking the application
// of server settings on client.
slist := []*Stream{}
slist := []*ClientStream{}
pctx, cancel := context.WithCancel(context.Background())
defer cancel()
timer := time.NewTimer(time.Second * 10)
@ -1035,7 +1035,7 @@ func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
onEachWrite: func() {},
})
// Loop until the server side stream is created.
var ss *Stream
var ss *ServerStream
for {
time.Sleep(time.Second)
sc.mu.Lock()
@ -1095,7 +1095,7 @@ func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) {
}
<-notifyChan
var sstream1 *Stream
var sstream1 *ServerStream
// Access stream on the server.
st.mu.Lock()
for _, v := range st.activeStreams {
@ -1121,7 +1121,7 @@ func (s) TestClientConnDecoupledFromApplicationRead(t *testing.T) {
t.Fatalf("Client failed to create second stream. Err: %v", err)
}
<-notifyChan
var sstream2 *Stream
var sstream2 *ServerStream
st.mu.Lock()
for _, v := range st.activeStreams {
if v.id == cstream2.id {
@ -1200,7 +1200,7 @@ func (s) TestServerConnDecoupledFromApplicationRead(t *testing.T) {
}
return false, nil
})
var sstream1 *Stream
var sstream1 *ServerStream
st.mu.Lock()
for _, v := range st.activeStreams {
if v.id == 1 {
@ -1654,7 +1654,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
const numStreams = 5
clientStreams := make([]*Stream, numStreams)
clientStreams := make([]*ClientStream, numStreams)
for i := 0; i < numStreams; i++ {
var err error
clientStreams[i], err = client.NewStream(ctx, &CallHdr{})
@ -1666,7 +1666,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
// For each stream send pingpong messages to the server.
for _, stream := range clientStreams {
wg.Add(1)
go func(stream *Stream) {
go func(stream *ClientStream) {
defer wg.Done()
buf := make([]byte, msgSize+5)
buf[0] = byte(0)
@ -1697,7 +1697,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
}(stream)
}
wg.Wait()
serverStreams := map[uint32]*Stream{}
serverStreams := map[uint32]*ServerStream{}
loopyClientStreams := map[uint32]*outStream{}
loopyServerStreams := map[uint32]*outStream{}
// Get all the streams from server reader and writer and client writer.
@ -2211,7 +2211,7 @@ func (s) TestWriteHeaderConnectionError(t *testing.T) {
}
<-notifyChan // Wait for server stream to be established.
var sstream *Stream
var sstream *ServerStream
// Access stream on the server.
serverTransport.mu.Lock()
for _, v := range serverTransport.activeStreams {
@ -2512,21 +2512,23 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) {
}
func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
testStream := func() *Stream {
return &Stream{
testStream := func() *ClientStream {
return &ClientStream{
Stream: &Stream{
buf: &recvBuffer{
c: make(chan recvMsg),
mu: sync.Mutex{},
},
},
done: make(chan struct{}),
headerChan: make(chan struct{}),
buf: &recvBuffer{
c: make(chan recvMsg),
mu: sync.Mutex{},
},
}
}
testClient := func(ts *Stream) *http2Client {
testClient := func(ts *ClientStream) *http2Client {
return &http2Client{
mu: sync.Mutex{},
activeStreams: map[uint32]*Stream{
activeStreams: map[uint32]*ClientStream{
0: ts,
},
controlBuf: newControlBuffer(make(<-chan struct{})),

View File

@ -817,7 +817,7 @@ func (p *payloadInfo) free() {
// the buffer is no longer needed.
// TODO: Refactor this function to reduce the number of arguments.
// See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool,
) (out mem.BufferSlice, err error) {
pf, compressed, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
@ -908,10 +908,14 @@ func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMes
return out, out.Len(), nil
}
type recvCompressor interface {
RecvCompress() string
}
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) error {
data, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor, isServer)
if err != nil {
return err

View File

@ -621,8 +621,8 @@ func bufferPool(bufferPool mem.BufferPool) ServerOption {
// workload (assuming a QPS of a few thousand requests/sec).
const serverWorkerResetThreshold = 1 << 16
// serverWorker blocks on a *transport.Stream channel forever and waits for
// data to be fed by serveStreams. This allows multiple requests to be
// serverWorker blocks on a *transport.ServerStream channel forever and waits
// for data to be fed by serveStreams. This allows multiple requests to be
// processed by the same goroutine, removing the need for expensive stack
// re-allocations (see the runtime.morestack problem [1]).
//
@ -1020,7 +1020,7 @@ func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport,
}()
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
st.HandleStreams(ctx, func(stream *transport.Stream) {
st.HandleStreams(ctx, func(stream *transport.ServerStream) {
s.handlersWG.Add(1)
streamQuota.acquire()
f := func() {
@ -1136,7 +1136,7 @@ func (s *Server) incrCallsFailed() {
s.channelz.ServerMetrics.CallsFailed.Add(1)
}
func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
func (s *Server) sendResponse(ctx context.Context, t transport.ServerTransport, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
if err != nil {
channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
@ -1212,7 +1212,7 @@ func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info
}
}
func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) {
func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTransport, stream *transport.ServerStream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) {
shs := s.opts.statsHandlers
if len(shs) != 0 || trInfo != nil || channelz.IsOn() {
if channelz.IsOn() {
@ -1541,7 +1541,7 @@ func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, inf
}
}
func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) {
func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTransport, stream *transport.ServerStream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) {
if channelz.IsOn() {
s.incrCallsStarted()
}
@ -1738,7 +1738,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
return t.WriteStatus(ss.s, statusOK)
}
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) {
ctx := stream.Context()
ctx = contextWithServer(ctx, s)
var ti *traceInfo
@ -2103,7 +2103,7 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
if !ok || stream == nil {
return fmt.Errorf("failed to fetch the stream from the given context")
}
@ -2125,7 +2125,7 @@ func SetSendCompressor(ctx context.Context, name string) error {
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
if !ok || stream == nil {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
}

View File

@ -164,13 +164,13 @@ func (s) TestRetryChainedInterceptor(t *testing.T) {
}
func (s) TestStreamContext(t *testing.T) {
expectedStream := &transport.Stream{}
expectedStream := &transport.ServerStream{}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = NewContextWithServerTransportStream(ctx, expectedStream)
s := ServerTransportStreamFromContext(ctx)
stream, ok := s.(*transport.Stream)
stream, ok := s.(*transport.ServerStream)
if !ok || expectedStream != stream {
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)
}

View File

@ -584,7 +584,7 @@ type csAttempt struct {
ctx context.Context
cs *clientStream
t transport.ClientTransport
s *transport.Stream
s *transport.ClientStream
p *parser
pickResult balancer.PickResult
@ -1340,7 +1340,7 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin
}
type addrConnStream struct {
s *transport.Stream
s *transport.ClientStream
ac *addrConn
callHdr *transport.CallHdr
cancel context.CancelFunc
@ -1578,7 +1578,7 @@ type ServerStream interface {
type serverStream struct {
ctx context.Context
t transport.ServerTransport
s *transport.Stream
s *transport.ServerStream
p *parser
codec baseCodec