diff --git a/test/end2end_test.go b/test/end2end_test.go index 3f55024b3..cf1e29154 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -795,17 +795,23 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) { s, cc := setUp(1, e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) - var err error - for { - time.Sleep(2 * time.Millisecond) - _, err = tc.StreamingInputCall(context.Background()) - // Loop until the settings of max concurrent streams is - // received by the client. - if err != nil { - break + // Perform an unary RPC to make sure the new settings were propagated to the client. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", tc, err) + } + // Initiate the 1st stream + if _, err := tc.StreamingInputCall(context.Background()); err != nil { + t.Fatalf("%v.StreamingInputCall(_) = %v, want ", tc, err) + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + // The 2nd stream should block until its deadline exceeds. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if _, err := tc.StreamingInputCall(ctx); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("%v.StreamingInputCall(%v) = _, %v, want error code %d", tc, ctx, err, codes.DeadlineExceeded) } - } - if grpc.Code(err) != codes.Unavailable { - t.Fatalf("got %v, want error code %d", err, codes.Unavailable) - } + wg.Done() + }() + wg.Wait() } diff --git a/transport/http2_client.go b/transport/http2_client.go index 6ba934489..35a69fc30 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -79,6 +79,8 @@ type http2Client struct { fc *inFlow // sendQuotaPool provides flow control to outbound message. sendQuotaPool *quotaPool + // streamsQuota limits the max number of concurrent streams. + streamsQuota *quotaPool // The scheme used: https if TLS is on, http otherwise. scheme string @@ -89,7 +91,7 @@ type http2Client struct { state transportState // the state of underlying connection activeStreams map[uint32]*Stream // The max number of concurrent streams - maxStreams uint32 + maxStreams int // the per-stream outbound flow control window size set by the peer. streamSendQuota uint32 } @@ -174,8 +176,8 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e scheme: scheme, state: reachable, activeStreams: make(map[uint32]*Stream), - maxStreams: math.MaxUint32, authCreds: opts.AuthOptions, + maxStreams: math.MaxInt32, streamSendQuota: defaultWindowSize, } go t.controller() @@ -236,19 +238,27 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea authData[k] = v } } - if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { - return nil, err - } t.mu.Lock() if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing } - if uint32(len(t.activeStreams)) >= t.maxStreams { - t.mu.Unlock() - t.writableChan <- 0 - return nil, StreamErrorf(codes.Unavailable, "transport: failed to create new stream because the limit has been reached.") + if t.streamsQuota != nil { + q, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire()) + if err != nil { + t.mu.Unlock() + return nil, err + } + // Returns the quota balance back. + if q > 1 { + t.streamsQuota.add(q - 1) + } } + t.mu.Unlock() + if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { + return nil, err + } + t.mu.Lock() s := t.newStream(ctx, callHdr) t.activeStreams[s.id] = s t.mu.Unlock() @@ -318,6 +328,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea func (t *http2Client) CloseStream(s *Stream, err error) { t.mu.Lock() delete(t.activeStreams, s.id) + if t.streamsQuota != nil { + t.streamsQuota.add(1) + } t.mu.Unlock() s.mu.Lock() if q := s.fc.restoreConn(); q > 0 { @@ -558,7 +571,18 @@ func (t *http2Client) handleSettings(f *http2.SettingsFrame) { defer t.mu.Unlock() switch s.ID { case http2.SettingMaxConcurrentStreams: - t.maxStreams = v + // TODO(zhaoq): This is a hack to avoid significant refactoring of the + // code to deal with the unrealistic int32 overflow. Probably will try + // to find a better way to handle this later. + if v > math.MaxInt32 { + v = math.MaxInt32 + } + if t.streamsQuota == nil { + t.streamsQuota = newQuotaPool(int(v)) + } else { + t.streamsQuota.reset(int(v) - t.maxStreams) + } + t.maxStreams = int(v) case http2.SettingInitialWindowSize: for _, s := range t.activeStreams { // Adjust the sending quota for each s. diff --git a/transport/transport_test.go b/transport/transport_test.go index adbb3e002..fafbbf23b 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -299,54 +299,6 @@ func TestClientMix(t *testing.T) { } } -func TestExceedMaxStreamsLimit(t *testing.T) { - server, ct := setUp(t, 0, 1, normal) - defer func() { - ct.Close() - server.stop() - }() - callHdr := &CallHdr{ - Host: "localhost", - Method: "foo.Small", - } - // Creates the 1st stream and keep it alive. - _, err1 := ct.NewStream(context.Background(), callHdr) - if err1 != nil { - t.Fatalf("failed to open stream: %v", err1) - } - // Creates the 2nd stream. It has chance to succeed when the settings - // frame from the server has not received at the client. - s, err2 := ct.NewStream(context.Background(), callHdr) - if err2 != nil { - se, ok := err2.(StreamError) - if !ok { - t.Fatalf("Received unexpected error %v", err2) - } - if se.Code != codes.Unavailable { - t.Fatalf("Got error code: %d, want: %d", se.Code, codes.Unavailable) - } - return - } - // If the 2nd stream is created successfully, sends the request. - if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil { - t.Fatalf("failed to send data: %v", err) - } - // The 2nd stream was rejected by the server via a reset. - p := make([]byte, len(expectedResponse)) - _, recvErr := io.ReadFull(s, p) - if recvErr != io.EOF || s.StatusCode() != codes.Unavailable { - t.Fatalf("Error: %v, StatusCode: %d; want , %d", recvErr, s.StatusCode(), codes.Unavailable) - } - // Server's setting has been received. From now on, new stream will be rejected instantly. - _, err3 := ct.NewStream(context.Background(), callHdr) - if err3 == nil { - t.Fatalf("Received unexpected , want an error with code %d", codes.Unavailable) - } - if se, ok := err3.(StreamError); !ok || se.Code != codes.Unavailable { - t.Fatalf("Got: %v, want a StreamError with error code %d", err3, codes.Unavailable) - } -} - func TestLargeMessage(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, normal) callHdr := &CallHdr{