diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 430cd454b..babcaee50 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -362,6 +362,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, + closeStream: func(err error) { + t.CloseStream(s, err) + }, }, windowHandler: func(n int) { t.updateWindow(s, uint32(n)) diff --git a/internal/transport/transport.go b/internal/transport/transport.go index e70a46fd5..2580aa7d3 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -110,15 +110,15 @@ func (b *recvBuffer) get() <-chan recvMsg { return b.c } -// // recvBufferReader implements io.Reader interface to read the data from // recvBuffer. type recvBufferReader struct { - ctx context.Context - ctxDone <-chan struct{} // cache of ctx.Done() (for performance). - recv *recvBuffer - last []byte // Stores the remaining data in the previous calls. - err error + closeStream func(error) // Closes the client transport stream with the given error and nil trailer metadata. + ctx context.Context + ctxDone <-chan struct{} // cache of ctx.Done() (for performance). + recv *recvBuffer + last []byte // Stores the remaining data in the previous calls. + err error } // Read reads the next len(p) bytes from last. If last is drained, it tries to @@ -128,31 +128,53 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { if r.err != nil { return 0, r.err } - n, r.err = r.read(p) - return n, r.err -} - -func (r *recvBufferReader) read(p []byte) (n int, err error) { if r.last != nil && len(r.last) > 0 { // Read remaining data left in last call. copied := copy(p, r.last) r.last = r.last[copied:] return copied, nil } + if r.closeStream != nil { + n, r.err = r.readClient(p) + } else { + n, r.err = r.read(p) + } + return n, r.err +} + +func (r *recvBufferReader) read(p []byte) (n int, err error) { select { case <-r.ctxDone: return 0, ContextErr(r.ctx.Err()) case m := <-r.recv.get(): - r.recv.load() - if m.err != nil { - return 0, m.err - } - copied := copy(p, m.data) - r.last = m.data[copied:] - return copied, nil + return r.readAdditional(m, p) } } +func (r *recvBufferReader) readClient(p []byte) (n int, err error) { + // If the context is canceled, then closes the stream with nil metadata. + // closeStream writes its error parameter to r.recv as a recvMsg. + // r.readAdditional acts on that message and returns the necessary error. + select { + case <-r.ctxDone: + r.closeStream(ContextErr(r.ctx.Err())) + m := <-r.recv.get() + return r.readAdditional(m, p) + case m := <-r.recv.get(): + return r.readAdditional(m, p) + } +} + +func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error) { + r.recv.load() + if m.err != nil { + return 0, m.err + } + copied := copy(p, m.data) + r.last = m.data[copied:] + return copied, nil +} + type streamState uint32 const ( diff --git a/stream.go b/stream.go index 0c266d6f9..d06279a20 100644 --- a/stream.go +++ b/stream.go @@ -462,10 +462,7 @@ func (cs *clientStream) shouldRetry(err error) error { pushback := 0 hasPushback := false if cs.attempt.s != nil { - if to, toErr := cs.attempt.s.TrailersOnly(); toErr != nil { - // Context error; stop now. - return toErr - } else if !to { + if to, toErr := cs.attempt.s.TrailersOnly(); toErr != nil || !to { return err } diff --git a/test/context_canceled_test.go b/test/context_canceled_test.go new file mode 100644 index 000000000..19d7124c9 --- /dev/null +++ b/test/context_canceled_test.go @@ -0,0 +1,78 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "testing" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +func (s) TestContextCanceled(t *testing.T) { + ss := &stubServer{ + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + stream.SetTrailer(metadata.New(map[string]string{"a": "b"})) + return status.Error(codes.PermissionDenied, "perm denied") + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + var i, cntCanceled uint + cntPermDenied := func() uint { + return i - cntCanceled + } + for i, cntCanceled = 0, 0; i < 500 && (cntCanceled < 5 || cntPermDenied() < 5); i++ { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + str, err := ss.client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", ss.client, err) + } + // As this duration goes up chances of Recv returning Cancelled will decrease. + time.Sleep(time.Duration(i) * time.Microsecond) + cancel() + _, err = str.Recv() + if err == nil { + t.Fatalf("non-nil error expected from Recv()") + } + code := status.Code(err) + if code == codes.Canceled { + cntCanceled++ + } + _, ok := str.Trailer()["a"] + if code == codes.PermissionDenied && !ok { + t.Fatalf(`status err: %v; wanted key "a" in trailer but didn't get it`, err) + } + if code == codes.Canceled && ok { + t.Fatalf(`status err: %v; didn't want key "a" in trailer but got it`, err) + } + } + if cntCanceled < 5 || cntPermDenied() < 5 { + t.Fatalf("got Canceled status %v times and PermissionDenied status %v times but wanted both of them at least 5 times", cntCanceled, cntPermDenied()) + } +}