mirror of https://github.com/grpc/grpc-go.git
acquire header mutex while copying trailers
This commit is contained in:
parent
0ebea3ebca
commit
2fe9a4f87c
|
@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
|
||||||
if err == nil { // transport has not been closed
|
if err == nil { // transport has not been closed
|
||||||
// Note: The trailer fields are compressed with hpack after this call returns.
|
// Note: The trailer fields are compressed with hpack after this call returns.
|
||||||
// No WireLength field is set here.
|
// No WireLength field is set here.
|
||||||
|
s.hdrMu.Lock()
|
||||||
for _, sh := range ht.stats {
|
for _, sh := range ht.stats {
|
||||||
sh.HandleRPC(s.Context(), &stats.OutTrailer{
|
sh.HandleRPC(s.Context(), &stats.OutTrailer{
|
||||||
Trailer: s.trailer.Copy(),
|
Trailer: s.trailer.Copy(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
s.hdrMu.Unlock()
|
||||||
}
|
}
|
||||||
ht.Close(errors.New("finished writing status"))
|
ht.Close(errors.New("finished writing status"))
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -35,6 +35,7 @@ import (
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/mem"
|
"google.golang.org/grpc/mem"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/stats"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/protoadapt"
|
"google.golang.org/protobuf/protoadapt"
|
||||||
|
@ -246,6 +247,22 @@ type handleStreamTest struct {
|
||||||
ht *serverHandlerTransport
|
ht *serverHandlerTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockStatsHandler struct{}
|
||||||
|
|
||||||
|
func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *mockStatsHandler) HandleRPC(context.Context, stats.RPCStats) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) {
|
||||||
|
}
|
||||||
|
|
||||||
func newHandleStreamTest(t *testing.T) *handleStreamTest {
|
func newHandleStreamTest(t *testing.T) *handleStreamTest {
|
||||||
bodyr, bodyw := io.Pipe()
|
bodyr, bodyw := io.Pipe()
|
||||||
req := &http.Request{
|
req := &http.Request{
|
||||||
|
@ -260,7 +277,12 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
|
||||||
Body: bodyr,
|
Body: bodyr,
|
||||||
}
|
}
|
||||||
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
||||||
ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
|
// Add mock stats handlers to exercise the stats handler code path.
|
||||||
|
statsHandlers := make([]stats.Handler, 0, 5)
|
||||||
|
for range 5 {
|
||||||
|
statsHandlers = append(statsHandlers, &mockStatsHandler{})
|
||||||
|
}
|
||||||
|
ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -485,6 +507,12 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
|
||||||
|
|
||||||
hst := newHandleStreamTest(t)
|
hst := newHandleStreamTest(t)
|
||||||
handleStream := func(s *ServerStream) {
|
handleStream := func(s *ServerStream) {
|
||||||
|
if err := s.SendHeader(metadata.New(map[string]string{})); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
s.WriteStatus(st)
|
s.WriteStatus(st)
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
@ -501,6 +529,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
|
||||||
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
|
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
|
||||||
"Grpc-Message": {encodeGrpcMessage(msg)},
|
"Grpc-Message": {encodeGrpcMessage(msg)},
|
||||||
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
|
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
|
||||||
|
"Custom-Trailer": []string{"Custom trailer value"},
|
||||||
}
|
}
|
||||||
|
|
||||||
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
|
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
|
||||||
|
|
Loading…
Reference in New Issue