From c195587d96d5ae30321b96a1e2e175fea09e9fda Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Wed, 10 Oct 2018 13:21:08 -0700 Subject: [PATCH] balancer: add trailer metadata to DoneInfo (#2359) --- balancer/balancer.go | 3 +++ internal/transport/http2_client.go | 8 +++++--- stream.go | 3 +++ test/balancer_test.go | 11 +++++++++-- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/balancer/balancer.go b/balancer/balancer.go index eb2231a4c..119c351f3 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -28,6 +28,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" ) @@ -160,6 +161,8 @@ type PickOptions struct { type DoneInfo struct { // Err is the rpc error the RPC finished with. It could be nil. Err error + // Trailer contains the metadata from the RPC's trailer, if present. + Trailer metadata.MD // BytesSent indicates if any bytes have been sent to the server. BytesSent bool // BytesReceived indicates if any byte has been received from the server. diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 3299fe9cb..2bd70597e 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -682,7 +682,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) { func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) { // Set stream status to done. if s.swapState(streamDone) == streamDone { - // If it was already done, return. + // If it was already done, return. If multiple closeStream calls + // happen simultaneously, wait for the first to finish. + <-s.done return } // status and trailers can be updated here without any synchronization because the stream goroutine will @@ -696,8 +698,6 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // This will unblock reads eventually. s.write(recvMsg{err: err}) } - // This will unblock write. - close(s.done) // If headerChan isn't closed, then close it. if atomic.SwapUint32(&s.headerDone, 1) == 0 { s.noHeaders = true @@ -733,6 +733,8 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. return true } t.controlBuf.executeAndPut(addBackStreamQuota, cleanup) + // This will unblock write. + close(s.done) } // Close kicks off the shutdown process of the transport. This should be called diff --git a/stream.go b/stream.go index 492b650d5..b71eb3112 100644 --- a/stream.go +++ b/stream.go @@ -816,11 +816,14 @@ func (a *csAttempt) finish(err error) { if a.done != nil { br := false + var tr metadata.MD if a.s != nil { br = a.s.BytesReceived() + tr = a.s.Trailer() } a.done(balancer.DoneInfo{ Err: err, + Trailer: tr, BytesSent: a.s != nil, BytesReceived: br, }) diff --git a/test/balancer_test.go b/test/balancer_test.go index e74377353..ec36e602c 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -152,7 +152,7 @@ func testDoneInfo(t *testing.T, e env) { grpc.WithBalancerName(testBalancerName), } te.userAgent = failAppUA - te.startServer(&testServer{security: e.security, unaryCallSleepTime: time.Second}) + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -164,7 +164,14 @@ func testDoneInfo(t *testing.T, e env) { if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) } - if len(b.doneInfo) != 1 || !reflect.DeepEqual(b.doneInfo[0].Err, wantErr) { + if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { + t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) + } + + if len(b.doneInfo) < 1 || !reflect.DeepEqual(b.doneInfo[0].Err, wantErr) { t.Fatalf("b.doneInfo = %v; want b.doneInfo[0].Err = %v", b.doneInfo, wantErr) } + if len(b.doneInfo) < 2 || !reflect.DeepEqual(b.doneInfo[1].Trailer, testTrailerMetadata) { + t.Fatalf("b.doneInfo = %v; want b.doneInfo[1].Trailer = %v", b.doneInfo, testTrailerMetadata) + } }