diff --git a/call.go b/call.go index 9d3c82c55..2318bf0e9 100644 --- a/call.go +++ b/call.go @@ -37,26 +37,34 @@ import ( "io" "github.com/golang/protobuf/proto" + "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/transport" - "golang.org/x/net/context" ) // recv receives and parses an RPC response. // On error, it returns the error and indicates whether the call should be retried. // // TODO(zhaoq): Check whether the received message sequence is valid. -func recv(stream *transport.Stream, reply proto.Message) error { +func recv(t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply proto.Message) error { + // Try to acquire header metadata from the server if there is any. + var err error + c.headerMD, err = stream.Header() + if err != nil { + return err + } p := &parser{s: stream} for { - if err := recvProto(p, reply); err != nil { + if err = recvProto(p, reply); err != nil { if err == io.EOF { - return nil + break } return err } } + c.trailerMD = stream.Trailer() + return nil } // sendRPC writes out various information of an RPC such as Context and Message. @@ -145,17 +153,11 @@ func Invoke(ctx context.Context, method string, args, reply proto.Message, cc *C } return toRPCErr(err) } - // Try to acquire header metadata from the server if there is any. - c.headerMD, err = stream.Header() - if err != nil { - return toRPCErr(err) - } // Receive the response - lastErr = recv(stream, reply) + lastErr = recv(t, &c, stream, reply) if _, ok := lastErr.(transport.ConnectionError); ok { continue } - c.trailerMD = stream.Trailer() t.CloseStream(stream, lastErr) if lastErr != nil { return toRPCErr(lastErr) diff --git a/stream.go b/stream.go index 53baacac6..2a3be4ab0 100644 --- a/stream.go +++ b/stream.go @@ -37,10 +37,10 @@ import ( "io" "github.com/golang/protobuf/proto" + "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/transport" - "golang.org/x/net/context" ) // Stream defines the common interface a client or server stream has to satisfy. @@ -112,8 +112,14 @@ func (cs *clientStream) Context() context.Context { return cs.s.Context() } -func (cs *clientStream) Header() (md metadata.MD, err error) { - return cs.s.Header() +func (cs *clientStream) Header() (metadata.MD, error) { + m, err := cs.s.Header() + if err != nil { + if _, ok := err.(transport.ConnectionError); !ok { + cs.t.CloseStream(cs.s, err) + } + } + return m, err } func (cs *clientStream) Trailer() metadata.MD { @@ -142,6 +148,9 @@ func (cs *clientStream) RecvProto(m proto.Message) (err error) { if err == nil { return } + if _, ok := err.(transport.ConnectionError); !ok { + cs.t.CloseStream(cs.s, err) + } if err == io.EOF { if cs.s.StatusCode() == codes.OK { // Returns io.EOF to indicate the end of the stream. @@ -149,9 +158,6 @@ func (cs *clientStream) RecvProto(m proto.Message) (err error) { } return Errorf(cs.s.StatusCode(), cs.s.StatusDesc()) } - if _, ok := err.(transport.ConnectionError); !ok { - cs.t.CloseStream(cs.s, err) - } return toRPCErr(err) }