client: fix ClientStream.Header() behavior (#6557)

This commit is contained in:
Doug Fawley 2023-08-18 08:05:48 -07:00 committed by GitHub
parent 8a2c220594
commit fe1519ecf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 110 additions and 70 deletions

View File

@ -31,10 +31,12 @@ import (
"github.com/golang/protobuf/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/binarylog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
iblog "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
@ -1059,3 +1061,39 @@ func (s) TestServerBinaryLogFullDuplexError(t *testing.T) {
t.Fatal(err)
}
}
// TestCanceledStatus ensures a server that responds with a Canceled status has
// its trailers logged appropriately and is not treated as a canceled RPC.
func (s) TestCanceledStatus(t *testing.T) {
defer testSink.clear()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
const statusMsgWant = "server returned Canceled"
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
grpc.SetTrailer(ctx, metadata.Pairs("key", "value"))
return nil, status.Error(codes.Canceled, statusMsgWant)
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Canceled {
t.Fatalf("Received unexpected error from UnaryCall: %v; want Canceled", err)
}
got := testSink.logEntries(true)
last := got[len(got)-1]
if last.Type != binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_TRAILER ||
last.GetTrailer().GetStatusCode() != uint32(codes.Canceled) ||
last.GetTrailer().GetStatusMessage() != statusMsgWant ||
len(last.GetTrailer().GetMetadata().GetEntry()) != 1 ||
last.GetTrailer().GetMetadata().GetEntry()[0].GetKey() != "key" ||
string(last.GetTrailer().GetMetadata().GetEntry()[0].GetValue()) != "value" {
t.Fatalf("Got binary log: %+v; want last entry is server trailing with status Canceled", got)
}
}

View File

@ -1505,14 +1505,15 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
return
}
isHeader := false
// If headerChan hasn't been closed yet
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
s.headerValid = true
if !endStream {
// HEADERS frame block carries a Response-Headers.
isHeader = true
// For headers, set them in s.header and close headerChan. For trailers or
// trailers-only, closeStream will set the trailers and close headerChan as
// needed.
if !endStream {
// If headerChan hasn't been closed yet (expected, given we checked it
// above, but something else could have potentially closed the whole
// stream).
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
s.headerValid = true
// These values can be set without any synchronization because
// stream goroutine will read it only after seeing a closed
// headerChan which we'll close after setting this.
@ -1520,15 +1521,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(mdata) > 0 {
s.header = mdata
}
} else {
// HEADERS frame block carries a Trailers-Only.
s.noHeaders = true
close(s.headerChan)
}
close(s.headerChan)
}
for _, sh := range t.statsHandlers {
if isHeader {
if !endStream {
inHeader := &stats.InHeader{
Client: true,
WireLength: int(frame.Header().Length),
@ -1554,9 +1552,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
statusGen = status.New(rawStatusCode, grpcMessage)
}
// if client received END_STREAM from server while stream was still active, send RST_STREAM
rst := s.getState() == streamActive
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, statusGen, mdata, true)
// If client received END_STREAM from server while stream was still active,
// send RST_STREAM.
rstStream := s.getState() == streamActive
t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, statusGen, mdata, true)
}
// readServerPreface reads and handles the initial settings frame from the

View File

@ -43,10 +43,6 @@ import (
"google.golang.org/grpc/tap"
)
// ErrNoHeaders is used as a signal that a trailers only response was received,
// and is not a real error.
var ErrNoHeaders = errors.New("stream has no headers")
const logLevel = 2
type bufferPool struct {
@ -390,14 +386,10 @@ func (s *Stream) Header() (metadata.MD, error) {
}
s.waitOnHeader()
if !s.headerValid {
if !s.headerValid || s.noHeaders {
return nil, s.status.Err()
}
if s.noHeaders {
return nil, ErrNoHeaders
}
return s.header.Copy(), nil
}

View File

@ -867,15 +867,18 @@ func Errorf(c codes.Code, format string, a ...any) error {
return status.Errorf(c, format, a...)
}
var errContextCanceled = status.Error(codes.Canceled, context.Canceled.Error())
var errContextDeadline = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())
// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
switch err {
case nil, io.EOF:
return err
case context.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
return errContextDeadline
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
return errContextCanceled
case io.ErrUnexpectedEOF:
return status.Error(codes.Internal, err.Error())
}

View File

@ -789,23 +789,23 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func())
func (cs *clientStream) Header() (metadata.MD, error) {
var m metadata.MD
noHeader := false
err := cs.withRetry(func(a *csAttempt) error {
var err error
m, err = a.s.Header()
if err == transport.ErrNoHeaders {
noHeader = true
return nil
}
return toRPCErr(err)
}, cs.commitAttemptLocked)
if m == nil && err == nil {
// The stream ended with success. Finish the clientStream.
err = io.EOF
}
if err != nil {
cs.finish(err)
return nil, err
}
if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && !noHeader {
if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && m != nil {
// Only log if binary log is on and header has not been logged, and
// there is actually headers to log.
logEntry := &binarylog.ServerHeader{
@ -821,6 +821,7 @@ func (cs *clientStream) Header() (metadata.MD, error) {
binlog.Log(cs.ctx, logEntry)
}
}
return m, nil
}
@ -929,24 +930,6 @@ func (cs *clientStream) RecvMsg(m any) error {
if err != nil || !cs.desc.ServerStreams {
// err != nil or non-server-streaming indicates end of stream.
cs.finish(err)
if len(cs.binlogs) != 0 {
// finish will not log Trailer. Log Trailer here.
logEntry := &binarylog.ServerTrailer{
OnClientSide: true,
Trailer: cs.Trailer(),
Err: err,
}
if logEntry.Err == io.EOF {
logEntry.Err = nil
}
if peer, ok := peer.FromContext(cs.Context()); ok {
logEntry.PeerAddr = peer.Addr
}
for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, logEntry)
}
}
}
return err
}
@ -1002,18 +985,30 @@ func (cs *clientStream) finish(err error) {
}
}
}
cs.mu.Unlock()
// For binary logging. only log cancel in finish (could be caused by RPC ctx
// canceled or ClientConn closed). Trailer will be logged in RecvMsg.
//
// Only one of cancel or trailer needs to be logged. In the cases where
// users don't call RecvMsg, users must have already canceled the RPC.
if len(cs.binlogs) != 0 && status.Code(err) == codes.Canceled {
c := &binarylog.Cancel{
OnClientSide: true,
}
for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, c)
// Only one of cancel or trailer needs to be logged.
if len(cs.binlogs) != 0 {
switch err {
case errContextCanceled, errContextDeadline, ErrClientConnClosing:
c := &binarylog.Cancel{
OnClientSide: true,
}
for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, c)
}
default:
logEntry := &binarylog.ServerTrailer{
OnClientSide: true,
Trailer: cs.Trailer(),
Err: err,
}
if peer, ok := peer.FromContext(cs.Context()); ok {
logEntry.PeerAddr = peer.Addr
}
for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, logEntry)
}
}
}
if err == nil {

View File

@ -6328,12 +6328,11 @@ func (s) TestGlobalBinaryLoggingOptions(t *testing.T) {
return &testpb.SimpleResponse{}, nil
},
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for {
_, err := stream.Recv()
if err == io.EOF {
return nil
}
_, err := stream.Recv()
if err == io.EOF {
return nil
}
return status.Errorf(codes.Unknown, "expected client to call CloseSend")
},
}

View File

@ -211,6 +211,11 @@ func (s) TestRetryStreaming(t *testing.T) {
return nil
}
}
sHdr := func() serverOp {
return func(stream testgrpc.TestService_FullDuplexCallServer) error {
return stream.SendHeader(metadata.Pairs("test_header", "test_value"))
}
}
sRes := func(b byte) serverOp {
return func(stream testgrpc.TestService_FullDuplexCallServer) error {
msg := res(b)
@ -222,7 +227,7 @@ func (s) TestRetryStreaming(t *testing.T) {
}
sErr := func(c codes.Code) serverOp {
return func(stream testgrpc.TestService_FullDuplexCallServer) error {
return status.New(c, "").Err()
return status.New(c, "this is a test error").Err()
}
}
sCloseSend := func() serverOp {
@ -270,7 +275,7 @@ func (s) TestRetryStreaming(t *testing.T) {
}
cErr := func(c codes.Code) clientOp {
return func(stream testgrpc.TestService_FullDuplexCallClient) error {
want := status.New(c, "").Err()
want := status.New(c, "this is a test error").Err()
if c == codes.OK {
want = io.EOF
}
@ -309,6 +314,11 @@ func (s) TestRetryStreaming(t *testing.T) {
cHdr := func() clientOp {
return func(stream testgrpc.TestService_FullDuplexCallClient) error {
_, err := stream.Header()
if err == io.EOF {
// The stream ended successfully; convert to nil to avoid
// erroring the test case.
err = nil
}
return err
}
}
@ -362,9 +372,13 @@ func (s) TestRetryStreaming(t *testing.T) {
sReq(1), sRes(3), sErr(codes.Unavailable),
},
clientOps: []clientOp{cReq(1), cRes(3), cErr(codes.Unavailable)},
}, {
desc: "Retry via ClientStream.Header()",
serverOps: []serverOp{sReq(1), sErr(codes.Unavailable), sReq(1), sAttempts(1)},
clientOps: []clientOp{cReq(1), cHdr() /* this should cause a retry */, cErr(codes.OK)},
}, {
desc: "No retry after header",
serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)},
serverOps: []serverOp{sReq(1), sHdr(), sErr(codes.Unavailable)},
clientOps: []clientOp{cReq(1), cHdr(), cErr(codes.Unavailable)},
}, {
desc: "No retry after context",