diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 05dfb9b0d..497c78c61 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -152,7 +152,7 @@ func isTemporary(err error) bool { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts ConnectOptions, onSuccess func()) (_ ClientTransport, err error) { +func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts ConnectOptions, onSuccess func()) (_ *http2Client, err error) { scheme := "http" ctx, cancel := context.WithCancel(ctx) defer func() { diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index cbf20386d..d83f8267e 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -396,30 +396,26 @@ func setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hTyp return server } -func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport, func()) { +func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) { return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{}, func() {}) } -func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions, onHandshake func()) (*server, ClientTransport, func()) { +func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions, onHandshake func()) (*server, *http2Client, func()) { server := setUpServerOnly(t, port, serverConfig, ht) addr := "localhost:" + server.port - var ( - ct ClientTransport - connErr error - ) target := TargetInfo{ Addr: addr, } connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second)) - ct, connErr = NewClientTransport(connectCtx, context.Background(), target, copts, onHandshake) + ct, connErr := NewClientTransport(connectCtx, context.Background(), target, copts, onHandshake) if connErr != nil { cancel() // Do not cancel in success path. t.Fatalf("failed to create transport: %v", connErr) } - return server, ct, cancel + return server, ct.(*http2Client), cancel } -func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Conn) (ClientTransport, func()) { +func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Conn) (*http2Client, func()) { lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen: %v", err) @@ -446,7 +442,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con } t.Fatalf("Failed to dial: %v", err) } - return tr, cancel + return tr.(*http2Client), cancel } // TestInflightStreamClosing ensures that closing in-flight stream @@ -504,7 +500,7 @@ func TestMaxConnectionIdle(t *testing.T) { if err != nil { t.Fatalf("Client failed to create RPC request: %v", err) } - client.(*http2Client).closeStream(stream, io.EOF, true, http2.ErrCodeCancel, nil, nil, false) + client.closeStream(stream, io.EOF, true, http2.ErrCodeCancel, nil, nil, false) // wait for server to see that closed stream and max-age logic to send goaway after no new RPCs are mode timeout := time.NewTimer(time.Second * 4) select { @@ -635,10 +631,9 @@ func TestKeepaliveServerNegative(t *testing.T) { // Give keepalive logic some time by sleeping. time.Sleep(4 * time.Second) // Assert that client is still active. - clientTr := client.(*http2Client) - clientTr.mu.Lock() - defer clientTr.mu.Unlock() - if clientTr.state != reachable { + client.mu.Lock() + defer client.mu.Unlock() + if client.state != reachable { t.Fatalf("Test failed: Expected server-client connection to be healthy.") } } @@ -660,10 +655,9 @@ func TestKeepaliveClientClosesIdleTransport(t *testing.T) { // Sleep for keepalive to close the connection. time.Sleep(4 * time.Second) // Assert that the connection was closed. - ct := tr.(*http2Client) - ct.mu.Lock() - defer ct.mu.Unlock() - if ct.state == reachable { + tr.mu.Lock() + defer tr.mu.Unlock() + if tr.state == reachable { t.Fatalf("Test Failed: Expected client transport to have closed.") } } @@ -684,10 +678,9 @@ func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { // Give keepalive some time. time.Sleep(4 * time.Second) // Assert that connections is still healthy. - ct := tr.(*http2Client) - ct.mu.Lock() - defer ct.mu.Unlock() - if ct.state != reachable { + tr.mu.Lock() + defer tr.mu.Unlock() + if tr.state != reachable { t.Fatalf("Test failed: Expected client transport to be healthy.") } } @@ -713,10 +706,9 @@ func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { // Give keepalive some time. time.Sleep(4 * time.Second) // Assert that transport was closed. - ct := tr.(*http2Client) - ct.mu.Lock() - defer ct.mu.Unlock() - if ct.state == reachable { + tr.mu.Lock() + defer tr.mu.Unlock() + if tr.state == reachable { t.Fatalf("Test failed: Expected client transport to have closed.") } } @@ -733,10 +725,9 @@ func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { // Give keep alive some time. time.Sleep(4 * time.Second) // Assert that transport is healthy. - ct := tr.(*http2Client) - ct.mu.Lock() - defer ct.mu.Unlock() - if ct.state != reachable { + tr.mu.Lock() + defer tr.mu.Unlock() + if tr.state != reachable { t.Fatalf("Test failed: Expected client transport to be healthy.") } } @@ -769,10 +760,9 @@ func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { t.Fatalf("Test failed: Expected a GoAway from server.") } time.Sleep(500 * time.Millisecond) - ct := client.(*http2Client) - ct.mu.Lock() - defer ct.mu.Unlock() - if ct.state == reachable { + client.mu.Lock() + defer client.mu.Unlock() + if client.state == reachable { t.Fatalf("Test failed: Expected the connection to be closed.") } } @@ -807,10 +797,9 @@ func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { t.Fatalf("Test failed: Expected a GoAway from server.") } time.Sleep(500 * time.Millisecond) - ct := client.(*http2Client) - ct.mu.Lock() - defer ct.mu.Unlock() - if ct.state == reachable { + client.mu.Lock() + defer client.mu.Unlock() + if client.state == reachable { t.Fatalf("Test failed: Expected the connection to be closed.") } } @@ -837,10 +826,9 @@ func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { // Give keepalive enough time. time.Sleep(3 * time.Second) // Assert that connection is healthy. - ct := client.(*http2Client) - ct.mu.Lock() - defer ct.mu.Unlock() - if ct.state != reachable { + client.mu.Lock() + defer client.mu.Unlock() + if client.state != reachable { t.Fatalf("Test failed: Expected connection to be healthy.") } } @@ -869,10 +857,9 @@ func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { // Give keepalive enough time. time.Sleep(3 * time.Second) // Assert that connection is healthy. - ct := client.(*http2Client) - ct.mu.Lock() - defer ct.mu.Unlock() - if ct.state != reachable { + client.mu.Lock() + defer client.mu.Unlock() + if client.state != reachable { t.Fatalf("Test failed: Expected connection to be healthy.") } } @@ -1259,10 +1246,9 @@ func TestMaxStreams(t *testing.T) { ct.CloseStream(s, nil) <-done ct.Close() - cc := ct.(*http2Client) - <-cc.writerDone - if cc.maxConcurrentStreams != 1 { - t.Fatalf("cc.maxConcurrentStreams: %d, want 1", cc.maxConcurrentStreams) + <-ct.writerDone + if ct.maxConcurrentStreams != 1 { + t.Fatalf("ct.maxConcurrentStreams: %d, want 1", ct.maxConcurrentStreams) } } @@ -1292,15 +1278,11 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { server.mu.Unlock() break } - cc, ok := ct.(*http2Client) - if !ok { - t.Fatalf("Failed to convert %v to *http2Client", ct) - } s, err := ct.NewStream(context.Background(), callHdr) if err != nil { t.Fatalf("Failed to open stream: %v", err) } - cc.controlBuf.put(&dataFrame{ + ct.controlBuf.put(&dataFrame{ streamID: s.id, endStream: false, h: nil, @@ -1320,7 +1302,7 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) { sc.mu.Unlock() break } - cc.Close() + ct.Close() select { case <-ss.Context().Done(): if ss.Context().Err() != context.Canceled { @@ -1835,7 +1817,6 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) st = k.(*http2Server) } server.mu.Unlock() - ct := client.(*http2Client) const numStreams = 10 clientStreams := make([]*Stream, numStreams) for i := 0; i < numStreams; i++ { @@ -1857,7 +1838,7 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) opts := Options{} header := make([]byte, 5) for i := 1; i <= 10; i++ { - if err := ct.Write(stream, nil, buf, &opts); err != nil { + if err := client.Write(stream, nil, buf, &opts); err != nil { t.Errorf("Error on client while writing message: %v", err) return } @@ -1888,25 +1869,25 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) id := stream.id serverStreams[id] = st.activeStreams[id] loopyServerStreams[id] = st.loopy.estdStreams[id] - loopyClientStreams[id] = ct.loopy.estdStreams[id] + loopyClientStreams[id] = client.loopy.estdStreams[id] } st.mu.Unlock() // Close all streams for _, stream := range clientStreams { - ct.Write(stream, nil, nil, &Options{Last: true}) + client.Write(stream, nil, nil, &Options{Last: true}) if _, err := stream.Read(make([]byte, 5)); err != io.EOF { t.Fatalf("Client expected an EOF from the server. Got: %v", err) } } // Close down both server and client so that their internals can be read without data // races. - ct.Close() + client.Close() st.Close() <-st.readerDone <-st.writerDone - <-ct.readerDone - <-ct.writerDone + <-client.readerDone + <-client.writerDone for _, cstream := range clientStreams { id := cstream.id sstream := serverStreams[id] @@ -1916,16 +1897,16 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig) if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding) } - if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding { - t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, ct.loopy.oiws, loopyClientStream.bytesOutStanding) + if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(client.loopy.oiws)-loopyClientStream.bytesOutStanding { + t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, client.loopy.oiws, loopyClientStream.bytesOutStanding) } } // Check transport flow control. - if ct.fc.limit != ct.fc.unacked+st.loopy.sendQuota { - t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", ct.fc.limit, ct.fc.unacked, st.loopy.sendQuota) + if client.fc.limit != client.fc.unacked+st.loopy.sendQuota { + t.Fatalf("Account mismatch: client transport inflow(%d) != client unacked(%d) + server sendQuota(%d)", client.fc.limit, client.fc.unacked, st.loopy.sendQuota) } - if st.fc.limit != st.fc.unacked+ct.loopy.sendQuota { - t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, ct.loopy.sendQuota) + if st.fc.limit != st.fc.unacked+client.loopy.sendQuota { + t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, client.loopy.sendQuota) } } @@ -2216,7 +2197,6 @@ func runPingPongTest(t *testing.T, msgSize int) { } return false, nil }) - ct := client.(*http2Client) stream, err := client.NewStream(context.Background(), &CallHdr{}) if err != nil { t.Fatalf("Failed to create stream. Err: %v", err) @@ -2236,13 +2216,13 @@ func runPingPongTest(t *testing.T, msgSize int) { for { select { case <-done: - ct.Write(stream, nil, nil, &Options{Last: true}) + client.Write(stream, nil, nil, &Options{Last: true}) if _, err := stream.Read(incomingHeader); err != io.EOF { t.Fatalf("Client expected EOF from the server. Got: %v", err) } return default: - if err := ct.Write(stream, outgoingHeader, msg, opts); err != nil { + if err := client.Write(stream, outgoingHeader, msg, opts); err != nil { t.Fatalf("Error on client while writing message. Err: %v", err) } if _, err := stream.Read(incomingHeader); err != nil { @@ -2344,7 +2324,7 @@ func TestHeaderTblSize(t *testing.T) { t.Fatalf("expected len(limits) = 1 within 10s, got != 1") } - ct.(*http2Client).controlBuf.put(&outgoingSettings{ + ct.controlBuf.put(&outgoingSettings{ ss: []http2.Setting{ { ID: http2.SettingHeaderTableSize,