Cherry-pick #8519 to v1.75.x (#8530)

Original PR: #8519
Related issue: #8514

RELEASE NOTES:
* transport: Fix a data race while copying headers for stats handlers in
the std lib http2 server transport.
This commit is contained in:
Arjan Singh Bal 2025-08-21 22:25:36 +05:30 committed by GitHub
parent 7269d5fe70
commit 369c9aa6ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 99 additions and 7 deletions

View File

@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
if err == nil { // transport has not been closed if err == nil { // transport has not been closed
// Note: The trailer fields are compressed with hpack after this call returns. // Note: The trailer fields are compressed with hpack after this call returns.
// No WireLength field is set here. // No WireLength field is set here.
s.hdrMu.Lock()
for _, sh := range ht.stats { for _, sh := range ht.stats {
sh.HandleRPC(s.Context(), &stats.OutTrailer{ sh.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(), Trailer: s.trailer.Copy(),
}) })
} }
s.hdrMu.Unlock()
} }
ht.Close(errors.New("finished writing status")) ht.Close(errors.New("finished writing status"))
return err return err

View File

@ -35,6 +35,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/mem" "google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt" "google.golang.org/protobuf/protoadapt"
@ -246,7 +247,26 @@ type handleStreamTest struct {
ht *serverHandlerTransport ht *serverHandlerTransport
} }
func newHandleStreamTest(t *testing.T) *handleStreamTest { type mockStatsHandler struct {
rpcStatsCh chan stats.RPCStats
}
func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
return ctx
}
func (h *mockStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
h.rpcStatsCh <- s
}
func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
return ctx
}
func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) {
}
func newHandleStreamTest(t *testing.T, statsHandlers []stats.Handler) *handleStreamTest {
bodyr, bodyw := io.Pipe() bodyr, bodyw := io.Pipe()
req := &http.Request{ req := &http.Request{
ProtoMajor: 2, ProtoMajor: 2,
@ -260,7 +280,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
Body: bodyr, Body: bodyr,
} }
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool()) ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -273,7 +293,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
} }
func (s) TestHandlerTransport_HandleStreams(t *testing.T) { func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st := newHandleStreamTest(t) st := newHandleStreamTest(t, nil)
handleStream := func(s *ServerStream) { handleStream := func(s *ServerStream) {
if want := "/service/foo.bar"; s.method != want { if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want) t.Errorf("stream method = %q; want %q", s.method, want)
@ -342,7 +362,7 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
} }
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
st := newHandleStreamTest(t) st := newHandleStreamTest(t, nil)
handleStream := func(s *ServerStream) { handleStream := func(s *ServerStream) {
s.WriteStatus(status.New(statusCode, msg)) s.WriteStatus(status.New(statusCode, msg))
@ -451,7 +471,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
} }
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) { func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
st := newHandleStreamTest(t) st := newHandleStreamTest(t, nil)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
t.Cleanup(cancel) t.Cleanup(cancel)
st.ht.HandleStreams( st.ht.HandleStreams(
@ -483,7 +503,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
hst := newHandleStreamTest(t) hst := newHandleStreamTest(t, nil)
handleStream := func(s *ServerStream) { handleStream := func(s *ServerStream) {
s.WriteStatus(st) s.WriteStatus(st)
} }
@ -506,11 +526,81 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer) checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
} }
// Tests the use of stats handlers and ensures there are no data races while
// accessing trailers.
func (s) TestHandlerTransport_HandleStreams_StatsHandlers(t *testing.T) {
errDetails := []protoadapt.MessageV1{
&epb.RetryInfo{
RetryDelay: &durationpb.Duration{Seconds: 60},
},
&epb.ResourceInfo{
ResourceType: "foo bar",
ResourceName: "service.foo.bar",
Owner: "User",
},
}
statusCode := codes.ResourceExhausted
msg := "you are being throttled"
st, err := status.New(statusCode, msg).WithDetails(errDetails...)
if err != nil {
t.Fatal(err)
}
stBytes, err := proto.Marshal(st.Proto())
if err != nil {
t.Fatal(err)
}
// Add mock stats handlers to exercise the stats handler code path.
statsHandler := &mockStatsHandler{
rpcStatsCh: make(chan stats.RPCStats, 2),
}
hst := newHandleStreamTest(t, []stats.Handler{statsHandler})
handleStream := func(s *ServerStream) {
if err := s.SendHeader(metadata.New(map[string]string{})); err != nil {
t.Error(err)
}
if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
t.Error(err)
}
s.WriteStatus(st)
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
hst.ht.HandleStreams(
ctx, func(s *ServerStream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
}
wantTrailer := http.Header{
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {encodeGrpcMessage(msg)},
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
"Custom-Trailer": []string{"Custom trailer value"},
}
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
wantStatTypes := []stats.RPCStats{&stats.OutHeader{}, &stats.OutTrailer{}}
for _, wantType := range wantStatTypes {
select {
case <-ctx.Done():
t.Fatal("Context timed out waiting for statsHandler.HandleRPC() to be called.")
case s := <-statsHandler.rpcStatsCh:
if reflect.TypeOf(s) != reflect.TypeOf(wantType) {
t.Fatalf("Received RPCStats of type %T, want %T", s, wantType)
}
}
}
}
// TestHandlerTransport_Drain verifies that Drain() is not implemented // TestHandlerTransport_Drain verifies that Drain() is not implemented
// by `serverHandlerTransport`. // by `serverHandlerTransport`.
func (s) TestHandlerTransport_Drain(t *testing.T) { func (s) TestHandlerTransport_Drain(t *testing.T) {
defer func() { recover() }() defer func() { recover() }()
st := newHandleStreamTest(t) st := newHandleStreamTest(t, nil)
st.ht.Drain("whatever") st.ht.Drain("whatever")
t.Errorf("serverHandlerTransport.Drain() should have panicked") t.Errorf("serverHandlerTransport.Drain() should have panicked")
} }