acquire header mutex while copying trailers

This commit is contained in:
Arjan Bal 2025-08-18 16:21:26 +05:30
parent 0ebea3ebca
commit 2fe9a4f87c
2 changed files with 32 additions and 1 deletions

View File

@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
if err == nil { // transport has not been closed if err == nil { // transport has not been closed
// Note: The trailer fields are compressed with hpack after this call returns. // Note: The trailer fields are compressed with hpack after this call returns.
// No WireLength field is set here. // No WireLength field is set here.
s.hdrMu.Lock()
for _, sh := range ht.stats { for _, sh := range ht.stats {
sh.HandleRPC(s.Context(), &stats.OutTrailer{ sh.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(), Trailer: s.trailer.Copy(),
}) })
} }
s.hdrMu.Unlock()
} }
ht.Close(errors.New("finished writing status")) ht.Close(errors.New("finished writing status"))
return err return err

View File

@ -35,6 +35,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/mem" "google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt" "google.golang.org/protobuf/protoadapt"
@ -246,6 +247,22 @@ type handleStreamTest struct {
ht *serverHandlerTransport ht *serverHandlerTransport
} }
type mockStatsHandler struct{}
func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
return ctx
}
func (h *mockStatsHandler) HandleRPC(context.Context, stats.RPCStats) {
}
func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
return ctx
}
func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) {
}
func newHandleStreamTest(t *testing.T) *handleStreamTest { func newHandleStreamTest(t *testing.T) *handleStreamTest {
bodyr, bodyw := io.Pipe() bodyr, bodyw := io.Pipe()
req := &http.Request{ req := &http.Request{
@ -260,7 +277,12 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
Body: bodyr, Body: bodyr,
} }
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool()) // Add mock stats handlers to exercise the stats handler code path.
statsHandlers := make([]stats.Handler, 0, 5)
for range 5 {
statsHandlers = append(statsHandlers, &mockStatsHandler{})
}
ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -485,6 +507,12 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
hst := newHandleStreamTest(t) hst := newHandleStreamTest(t)
handleStream := func(s *ServerStream) { handleStream := func(s *ServerStream) {
if err := s.SendHeader(metadata.New(map[string]string{})); err != nil {
t.Error(err)
}
if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
t.Error(err)
}
s.WriteStatus(st) s.WriteStatus(st)
} }
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
@ -501,6 +529,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
"Grpc-Status": {fmt.Sprint(uint32(statusCode))}, "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {encodeGrpcMessage(msg)}, "Grpc-Message": {encodeGrpcMessage(msg)},
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)}, "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
"Custom-Trailer": []string{"Custom trailer value"},
} }
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer) checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)