transport: Propagate status code on receiving RST_STREAM during message read (#8289) (#8317)

This commit is contained in:
Arjan Singh Bal 2025-05-14 08:34:39 +05:30 committed by GitHub
parent f32eab3f63
commit 537fe8d2c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 81 additions and 3 deletions

View File

@ -1242,7 +1242,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
statusCode = codes.DeadlineExceeded
}
}
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false)
st := status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode)
t.closeStream(s, st.Err(), false, http2.ErrCodeNo, st, nil, false)
}
func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) {

View File

@ -919,8 +919,9 @@ func (s) TestLargeMessageSuspension(t *testing.T) {
}
// The server will send an RST stream frame on observing the deadline
// expiration making the client stream fail with a DeadlineExceeded status.
if _, err := s.readTo(make([]byte, 8)); err != io.EOF {
t.Fatalf("Read got unexpected error: %v, want %v", err, io.EOF)
_, err = s.readTo(make([]byte, 8))
if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
t.Fatalf("Read got unexpected error: %v, want status with code %v", err, codes.DeadlineExceeded)
}
if got, want := s.Status().Code(), codes.DeadlineExceeded; got != want {
t.Fatalf("Read got status %v with code %v, want %v", s.Status(), got, want)

View File

@ -19,16 +19,20 @@ package test
import (
"context"
"encoding/binary"
"io"
"net"
"sync"
"testing"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
@ -153,3 +157,75 @@ func (s) TestClientTransportRestartsAfterStreamIDExhausted(t *testing.T) {
t.Fatal("Timeout expired when waiting for first client transport to close")
}
}
// Tests that an RST_STREAM frame that causes an io.ErrUnexpectedEOF while
// reading a gRPC message is correctly converted to a gRPC status with code
// CANCELLED. The test sends a data frame with a partial gRPC message, followed
// by an RST_STREAM frame with HTTP/2 code CANCELLED. The test asserts the
// client receives the correct status.
func (s) TestRSTDuringMessageRead(t *testing.T) {
lis, err := testutils.LocalTCPListener()
if err != nil {
t.Fatal(err)
}
defer lis.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("grpc.NewClient(%s) = %v", lis.Addr().String(), err)
}
defer cc.Close()
go func() {
conn, err := lis.Accept()
if err != nil {
t.Errorf("lis.Accept() = %v", err)
return
}
defer conn.Close()
framer := http2.NewFramer(conn, conn)
if _, err := io.ReadFull(conn, make([]byte, len(clientPreface))); err != nil {
t.Errorf("Error while reading client preface: %v", err)
return
}
if err := framer.WriteSettings(); err != nil {
t.Errorf("Error while writing settings: %v", err)
return
}
if err := framer.WriteSettingsAck(); err != nil {
t.Errorf("Error while writing settings: %v", err)
return
}
for ctx.Err() == nil {
frame, err := framer.ReadFrame()
if err != nil {
return
}
switch frame := frame.(type) {
case *http2.HeadersFrame:
// When the client creates a stream, write a partial gRPC
// message followed by an RST_STREAM.
const messageLen = 2048
buf := make([]byte, messageLen/2)
// Write the gRPC message length header.
binary.BigEndian.PutUint32(buf[1:5], uint32(messageLen))
if err := framer.WriteData(1, false, buf); err != nil {
return
}
framer.WriteRSTStream(1, http2.ErrCodeCancel)
default:
t.Logf("Server received frame: %v", frame)
}
}
}()
// The server will send a partial gRPC message before cancelling the stream.
// The client should get a gRPC status with code CANCELLED.
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Canceled {
t.Fatalf("client.EmptyCall() returned %v; want status with code %v", err, codes.Canceled)
}
}