cleanup: Remove test contexts without timeouts (#8072)

Arjan Singh Bal 2025-02-12 00:39:01 +05:30 committed by GitHub
parent e95a4b7136
commit ad5cd321d0
13 changed files with 167 additions and 106 deletions
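Every file in this commit follows the same three-step pattern: declare a package-level defaultTestTimeout of 10 seconds, derive a bounded context at the top of the test (or helper), and thread that context through everywhere context.Background() was previously passed. A minimal, self-contained sketch of the pattern, assuming an illustrative test name and a hypothetical workDone operation (neither is from this commit):

package example

import (
	"context"
	"testing"
	"time"
)

var defaultTestTimeout = 10 * time.Second

// TestWithBoundedContext shows the cleanup pattern: every context-aware
// operation inherits the 10s bound, so a hung dependency fails the test
// instead of stalling the whole suite.
func TestWithBoundedContext(t *testing.T) {
	// Before this cleanup: ctx := context.Background(), which never expires.
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel() // releases the timer even when the test finishes early

	select {
	case <-workDone(): // hypothetical operation under test
	case <-ctx.Done():
		t.Fatalf("timed out waiting for work: %v", ctx.Err())
	}
}

// workDone stands in for whatever asynchronous work a real test awaits.
func workDone() <-chan struct{} {
	ch := make(chan struct{})
	close(ch)
	return ch
}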

View File

@@ -22,6 +22,7 @@ import (
 	"context"
 	"net"
 	"testing"
+	"time"

 	"github.com/google/go-cmp/cmp"
 	"google.golang.org/grpc/credentials"
@@ -31,6 +32,8 @@ import (
 	"google.golang.org/grpc/resolver"
 )

+var defaultTestTimeout = 10 * time.Second
+
 type s struct {
 	grpctest.Tester
 }
@@ -103,6 +106,8 @@ func overrideNewCredsFuncs() func() {
 // modes), ClientHandshake does either tls or alts base on the cluster name in
 // attributes.
 func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	defer overrideNewCredsFuncs()()
 	for bundleTyp, tc := range map[string]credentials.Bundle{
 		"defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}),
@@ -116,12 +121,12 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
 	}{
 		{
 			name:    "no cluster name",
-			ctx:     context.Background(),
+			ctx:     ctx,
 			wantTyp: "tls",
 		},
 		{
 			name: "with non-CFE cluster name",
-			ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
+			ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
 				Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
 			}),
 			// non-CFE backends should use alts.
@@ -129,7 +134,7 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
 		},
 		{
 			name: "with CFE cluster name",
-			ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
+			ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
 				Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
 			}),
 			// CFE should use tls.
@@ -137,7 +142,7 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
 		},
 		{
 			name: "with xdstp CFE cluster name",
-			ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
+			ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
 				Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
 			}),
 			// CFE should use tls.
@@ -145,7 +150,7 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
 		},
 		{
 			name: "with xdstp non-CFE cluster name",
-			ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
+			ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
 				Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
 			}),
 			// non-CFE should use atls.
@@ -176,6 +181,8 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
 }

 func TestDefaultCredentialsWithOptions(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	md1 := map[string]string{"foo": "tls"}
 	md2 := map[string]string{"foo": "alts"}
 	tests := []struct {
@@ -248,7 +255,7 @@ func TestDefaultCredentialsWithOptions(t *testing.T) {
 		t.Run(tc.desc, func(t *testing.T) {
 			bundle := NewDefaultCredentialsWithOptions(tc.defaultCredsOpts)
 			ri := credentials.RequestInfo{AuthInfo: tc.authInfo}
-			ctx := icredentials.NewRequestInfoContext(context.Background(), ri)
+			ctx := icredentials.NewRequestInfoContext(ctx, ri)
 			got, err := bundle.PerRPCCredentials().GetRequestMetadata(ctx, "uri")
 			if err != nil {
 				t.Fatalf("Bundle's PerRPCCredentials().GetRequestMetadata() unexpected error = %v", err)

View File

@@ -29,8 +29,10 @@ import (
 )

 func (s) TestIsDirectPathCluster(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	c := func(cluster string) context.Context {
-		return icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
+		return icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
 			Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, cluster).Attributes,
 		})
 	}
@@ -40,7 +42,7 @@ func (s) TestIsDirectPathCluster(t *testing.T) {
 		ctx  context.Context
 		want bool
 	}{
-		{"not an xDS cluster", context.Background(), false},
+		{"not an xDS cluster", ctx, false},
 		{"cfe", c("google_cfe_bigtable.googleapis.com"), false},
 		{"non-cfe", c("google_bigtable.googleapis.com"), true},
 		{"starts with xdstp but not cfe format", c("xdstp:google_cfe_bigtable.googleapis.com"), true},

View File

@@ -312,8 +312,10 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
 		st.bodyw.Close() // no body
 		s.WriteStatus(status.New(codes.OK, ""))
 	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	st.ht.HandleStreams(
-		context.Background(), func(s *ServerStream) { go handleStream(s) },
+		ctx, func(s *ServerStream) { go handleStream(s) },
 	)
 	wantHeader := http.Header{
 		"Date": nil,
@@ -345,8 +347,10 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
 	handleStream := func(s *ServerStream) {
 		s.WriteStatus(status.New(statusCode, msg))
 	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	st.ht.HandleStreams(
-		context.Background(), func(s *ServerStream) { go handleStream(s) },
+		ctx, func(s *ServerStream) { go handleStream(s) },
 	)
 	wantHeader := http.Header{
 		"Date": nil,
@@ -394,8 +398,10 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
 		}
 		s.WriteStatus(status.New(codes.DeadlineExceeded, "too slow"))
 	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	ht.HandleStreams(
-		context.Background(), func(s *ServerStream) { go runStream(s) },
+		ctx, func(s *ServerStream) { go runStream(s) },
 	)
 	wantHeader := http.Header{
 		"Date": nil,
@@ -446,8 +452,10 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {

 func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
 	st := newHandleStreamTest(t)
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	t.Cleanup(cancel)
 	st.ht.HandleStreams(
-		context.Background(), func(s *ServerStream) { go handleStream(st, s) },
+		ctx, func(s *ServerStream) { go handleStream(st, s) },
 	)
 }
@@ -479,8 +487,10 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
 	handleStream := func(s *ServerStream) {
 		s.WriteStatus(st)
 	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	hst.ht.HandleStreams(
-		context.Background(), func(s *ServerStream) { go handleStream(s) },
+		ctx, func(s *ServerStream) { go handleStream(s) },
 	)
 	wantHeader := http.Header{
 		"Date": nil,

View File

@@ -381,21 +381,23 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 	h := &testStreamHandler{t: transport.(*http2Server)}
 	s.h = h
 	s.mu.Unlock()
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	switch ht {
 	case notifyCall:
-		go transport.HandleStreams(context.Background(), h.handleStreamAndNotify)
+		go transport.HandleStreams(ctx, h.handleStreamAndNotify)
 	case suspended:
-		go transport.HandleStreams(context.Background(), func(*ServerStream) {})
+		go transport.HandleStreams(ctx, func(*ServerStream) {})
 	case misbehaved:
-		go transport.HandleStreams(context.Background(), func(s *ServerStream) {
+		go transport.HandleStreams(ctx, func(s *ServerStream) {
 			go h.handleStreamMisbehave(t, s)
 		})
 	case encodingRequiredStatus:
-		go transport.HandleStreams(context.Background(), func(s *ServerStream) {
+		go transport.HandleStreams(ctx, func(s *ServerStream) {
 			go h.handleStreamEncodingRequiredStatus(s)
 		})
 	case invalidHeaderField:
-		go transport.HandleStreams(context.Background(), func(s *ServerStream) {
+		go transport.HandleStreams(ctx, func(s *ServerStream) {
 			go h.handleStreamInvalidHeaderField(s)
 		})
 	case delayRead:
@@ -404,15 +406,15 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 		s.mu.Lock()
 		close(s.ready)
 		s.mu.Unlock()
-		go transport.HandleStreams(context.Background(), func(s *ServerStream) {
+		go transport.HandleStreams(ctx, func(s *ServerStream) {
 			go h.handleStreamDelayRead(t, s)
 		})
 	case pingpong:
-		go transport.HandleStreams(context.Background(), func(s *ServerStream) {
+		go transport.HandleStreams(ctx, func(s *ServerStream) {
 			go h.handleStreamPingPong(t, s)
 		})
 	default:
-		go transport.HandleStreams(context.Background(), func(s *ServerStream) {
+		go transport.HandleStreams(ctx, func(s *ServerStream) {
 			go h.handleStream(t, s)
 		})
 	}
@@ -464,13 +466,15 @@ func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts
 	addr := resolver.Address{Addr: "localhost:" + server.port}
 	copts.ChannelzParent = channelzSubChannel(t)
-	connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
-	ct, connErr := NewHTTP2Client(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	t.Cleanup(cancel)
+	connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second)
+	ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {})
 	if connErr != nil {
-		cancel() // Do not cancel in success path.
+		cCancel() // Do not cancel in success path.
 		t.Fatalf("failed to create transport: %v", connErr)
 	}
-	return server, ct.(*http2Client), cancel
+	return server, ct.(*http2Client), cCancel
 }

 func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.Conn) (*http2Client, func()) {
@@ -495,10 +499,12 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.C
 		}
 		connCh <- conn
 	}()
-	connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
-	tr, err := NewHTTP2Client(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	t.Cleanup(cancel)
+	connectCtx, cCancel := context.WithTimeout(context.Background(), 2*time.Second)
+	tr, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
 	if err != nil {
-		cancel() // Do not cancel in success path.
+		cCancel() // Do not cancel in success path.
 		// Server clean-up.
 		lis.Close()
 		if conn, ok := <-connCh; ok {
@@ -506,7 +512,7 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.C
 		}
 		t.Fatalf("Failed to dial: %v", err)
 	}
-	return tr.(*http2Client), cancel
+	return tr.(*http2Client), cCancel
 }

 // TestInflightStreamClosing ensures that closing in-flight stream
@@ -739,7 +745,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
 		Host:   "localhost",
 		Method: "foo.Large",
 	}
-	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 	defer cancel()
 	s, err := ct.NewStream(ctx, callHdr)
 	if err != nil {
@@ -833,7 +839,7 @@ func (s) TestGracefulClose(t *testing.T) {
 		// Correctly clean up the server
 		server.stop()
 	}()
-	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 	defer cancel()

 	// Create a stream that will exist for this whole test and confirm basic
@@ -969,7 +975,7 @@ func (s) TestMaxStreams(t *testing.T) {
 	// Try and create a new stream.
 	go func() {
 		defer close(done)
-		ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10))
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 		defer cancel()
 		if _, err := ct.NewStream(ctx, callHdr); err != nil {
 			t.Errorf("Failed to open stream: %v", err)
@@ -1353,7 +1359,9 @@ func (s) TestClientHonorsConnectContext(t *testing.T) {
 	parent := channelzSubChannel(t)
 	copts := ConnectOptions{ChannelzParent: parent}
-	_, err = NewHTTP2Client(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	_, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
 	if err == nil {
 		t.Fatalf("NewHTTP2Client() returned successfully; wanted error")
 	}
@@ -1365,7 +1373,7 @@ func (s) TestClientHonorsConnectContext(t *testing.T) {
 	// Test context deadline.
 	connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer cancel()
-	_, err = NewHTTP2Client(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
+	_, err = NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
 	if err == nil {
 		t.Fatalf("NewHTTP2Client() returned successfully; wanted error")
 	}
@@ -1440,12 +1448,14 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
 			}
 		}
 	}()
-	connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
+	connectCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 	defer cancel()
 	parent := channelzSubChannel(t)
 	copts := ConnectOptions{ChannelzParent: parent}
-	ct, err := NewHTTP2Client(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	ct, err := NewHTTP2Client(connectCtx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
 	if err != nil {
 		t.Fatalf("Error while creating client transport: %v", err)
 	}
@@ -1779,9 +1789,11 @@ func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
 // If any error occurs on a call to Stream.Read, future calls
 // should continue to return that same error.
 func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	testRecvBuffer := newRecvBuffer()
 	s := &Stream{
-		ctx:         context.Background(),
+		ctx:         ctx,
 		buf:         testRecvBuffer,
 		requestRead: func(int) {},
 	}
@@ -2450,7 +2462,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
 		Addr:       "localhost:" + server.port,
 		Attributes: attributes.New(testAttrKey, testAttrVal),
 	}
-	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 	defer cancel()

 	creds := &attrTransportCreds{}
@@ -2485,7 +2497,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) {
 		Addr:       "localhost:" + server.port,
 		Attributes: attributes.New(testAttrKey, testAttrVal),
 	}
-	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
 	defer cancel()

 	var attr *attributes.Attributes
@@ -2829,7 +2841,7 @@ func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) {
 	// Create a client transport with a custom dialer that hangs the Read()
 	// after Close().
-	ct, err := NewHTTP2Client(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
+	ct, err := NewHTTP2Client(ctx, ctx, addr, copts, func(GoAwayReason) {})
 	if err != nil {
 		t.Fatalf("Failed to create transport: %v", err)
 	}
@@ -2915,14 +2927,14 @@ func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) {
 	}
 	copts := ConnectOptions{Dialer: dialer}
 	copts.ChannelzParent = channelzSubChannel(t)
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	// Create client transport with custom dialer
-	ct, connErr := NewHTTP2Client(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
+	ct, connErr := NewHTTP2Client(connectCtx, ctx, addr, copts, func(GoAwayReason) {})
 	if connErr != nil {
 		t.Fatalf("failed to create transport: %v", connErr)
 	}
-	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
-	defer cancel()
 	if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
 		t.Fatalf("Failed to open stream: %v", err)
 	}
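Two cancellation styles appear in this file: tests call defer cancel(), while helpers such as setUpWithOptions and setUpWithNoPingServer register t.Cleanup(cancel). The distinction matters because a deferred cancel fires when the helper returns, which would kill the context while the calling test is still using the returned transport; t.Cleanup postpones it until the test ends. A short sketch of that distinction, with an illustrative helper name (not from this commit):

package example

import (
	"context"
	"testing"
	"time"
)

var defaultTestTimeout = 10 * time.Second

// newTestContext returns a context that stays valid for the whole calling
// test. Using defer cancel() here would be wrong: the context would be
// canceled the moment this helper returned.
func newTestContext(t *testing.T) context.Context {
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	t.Cleanup(cancel) // runs after the calling test (and its subtests) finish
	return ctx
}

func TestUsesHelper(t *testing.T) {
	ctx := newTestContext(t)
	_ = ctx // ctx remains usable for the full test body
}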

View File

@@ -77,10 +77,7 @@ git grep '"github.com/envoyproxy/go-control-plane/envoy' -- '*.go' ':(exclude)*.

 # - Ensure all context usages are done with timeout.
 # Context tests under benchmark are excluded as they are testing the performance of context.Background() and context.TODO().
-# TODO: Remove the exclusions once the tests are updated to use context.WithTimeout().
-# See https://github.com/grpc/grpc-go/issues/7304
-git grep -e 'context.Background()' --or -e 'context.TODO()' -- "*_test.go" | grep -v "benchmark/primitives/context_test.go" | grep -v "credentials/google" | grep -v "internal/transport/" | grep -v "xds/internal/" | grep -v "security/advancedtls" | grep -v 'context.WithTimeout(' | not grep -v 'context.WithCancel('
+git grep -e 'context.Background()' --or -e 'context.TODO()' -- "*_test.go" | grep -v "benchmark/primitives/context_test.go" | grep -v 'context.WithTimeout(' | not grep -v 'context.WithCancel('

 # Disallow usage of net.ParseIP in favour of netip.ParseAddr as the former
 # can't parse link local IPv6 addresses.

View File

@@ -872,8 +872,9 @@ func (s) TestClientServerHandshake(t *testing.T) {
 			if err != nil {
 				t.Fatalf("NewClientCreds failed: %v", err)
 			}
-			_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(),
-				lisAddr, conn)
+			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+			defer cancel()
+			_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(ctx, lisAddr, conn)
 			// wait until server sends serverAuthInfo or fails.
 			serverAuthInfo, ok := <-done
 			if !ok && test.serverExpectError {
@@ -906,7 +907,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
 				cert, _ := test.clientGetCert(&tls.CertificateRequestInfo{})
 				clientCert = cert
 			} else if test.clientIdentityProvider != nil {
-				km, _ := test.clientIdentityProvider.KeyMaterial(context.TODO())
+				km, _ := test.clientIdentityProvider.KeyMaterial(ctx)
 				clientCert = &km.Certs[0]
 			}
 			if !bytes.Equal((*serverVerifiedChains[0][0]).Raw, clientCert.Certificate[0]) {
@@ -920,7+921,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
 				result, _ := test.serverGetRoot(&ConnectionInfo{})
 				serverRoot = result.TrustCerts
 			} else if test.serverRootProvider != nil {
-				km, _ := test.serverRootProvider.KeyMaterial(context.TODO())
+				km, _ := test.serverRootProvider.KeyMaterial(ctx)
 				serverRoot = km.Roots
 			}
 			serverVerifiedChainsCp := x509.NewCertPool()
@@ -941,7 +942,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
 				cert, _ := test.serverGetCert(&tls.ClientHelloInfo{})
 				serverCert = cert[0]
 			} else if test.serverIdentityProvider != nil {
-				km, _ := test.serverIdentityProvider.KeyMaterial(context.TODO())
+				km, _ := test.serverIdentityProvider.KeyMaterial(ctx)
 				serverCert = &km.Certs[0]
 			}
 			if !bytes.Equal((*clientVerifiedChains[0][0]).Raw, serverCert.Certificate[0]) {
@@ -955,7 +956,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
 				result, _ := test.clientGetRoot(&ConnectionInfo{})
 				clientRoot = result.TrustCerts
 			} else if test.clientRootProvider != nil {
-				km, _ := test.clientRootProvider.KeyMaterial(context.TODO())
+				km, _ := test.clientRootProvider.KeyMaterial(ctx)
 				clientRoot = km.Roots
 			}
 			clientVerifiedChainsCp := x509.NewCertPool()

View File

@@ -121,6 +121,8 @@ func TestClusterPicks(t *testing.T) {
 		sc.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
 	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	p1 := <-cc.NewPickerCh
 	for _, tt := range []struct {
 		pickInfo balancer.PickInfo
@@ -129,19 +131,19 @@ func TestClusterPicks(t *testing.T) {
 	}{
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_1"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_1"),
 			},
 			wantSC: m1[wantAddrs[0]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_2"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_2"),
 			},
 			wantSC: m1[wantAddrs[1]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "notacluster"),
+				Ctx: SetPickedCluster(ctx, "notacluster"),
 			},
 			wantErr: status.Errorf(codes.Unavailable, `unknown cluster selected for RPC: "notacluster"`),
 		},
@@ -201,6 +203,8 @@ func TestConfigUpdateAddCluster(t *testing.T) {
 	}
 	p1 := <-cc.NewPickerCh
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	for _, tt := range []struct {
 		pickInfo balancer.PickInfo
 		wantSC   balancer.SubConn
@@ -208,19 +212,19 @@ func TestConfigUpdateAddCluster(t *testing.T) {
 	}{
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_1"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_1"),
 			},
 			wantSC: m1[wantAddrs[0]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_2"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_2"),
 			},
 			wantSC: m1[wantAddrs[1]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:notacluster"),
+				Ctx: SetPickedCluster(ctx, "cds:notacluster"),
 			},
 			wantErr: status.Errorf(codes.Unavailable, `unknown cluster selected for RPC: "cds:notacluster"`),
 		},
@@ -281,25 +285,25 @@ func TestConfigUpdateAddCluster(t *testing.T) {
 	}{
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_1"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_1"),
 			},
 			wantSC: m1[wantAddrs[0]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_2"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_2"),
 			},
 			wantSC: m1[wantAddrs[1]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_3"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_3"),
 			},
 			wantSC: m1[wantAddrs[2]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:notacluster"),
+				Ctx: SetPickedCluster(ctx, "cds:notacluster"),
 			},
 			wantErr: status.Errorf(codes.Unavailable, `unknown cluster selected for RPC: "cds:notacluster"`),
 		},
@@ -359,6 +363,8 @@ func TestRoutingConfigUpdateDeleteAll(t *testing.T) {
 	}
 	p1 := <-cc.NewPickerCh
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	for _, tt := range []struct {
 		pickInfo balancer.PickInfo
 		wantSC   balancer.SubConn
@@ -366,19 +372,19 @@ func TestRoutingConfigUpdateDeleteAll(t *testing.T) {
 	}{
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_1"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_1"),
 			},
 			wantSC: m1[wantAddrs[0]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_2"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_2"),
 			},
 			wantSC: m1[wantAddrs[1]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:notacluster"),
+				Ctx: SetPickedCluster(ctx, "cds:notacluster"),
 			},
 			wantErr: status.Errorf(codes.Unavailable, `unknown cluster selected for RPC: "cds:notacluster"`),
 		},
@@ -409,7 +415,7 @@ func TestRoutingConfigUpdateDeleteAll(t *testing.T) {
 	p2 := <-cc.NewPickerCh
 	for i := 0; i < 5; i++ {
-		gotSCSt, err := p2.Pick(balancer.PickInfo{Ctx: SetPickedCluster(context.Background(), "cds:notacluster")})
+		gotSCSt, err := p2.Pick(balancer.PickInfo{Ctx: SetPickedCluster(ctx, "cds:notacluster")})
 		if fmt.Sprint(err) != status.Errorf(codes.Unavailable, `unknown cluster selected for RPC: "cds:notacluster"`).Error() {
 			t.Fatalf("picker.Pick, got %v, %v, want error %v", gotSCSt, err, `unknown cluster selected for RPC: "cds:notacluster"`)
 		}
@@ -450,19 +456,19 @@ func TestRoutingConfigUpdateDeleteAll(t *testing.T) {
 	}{
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_1"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_1"),
 			},
 			wantSC: m2[wantAddrs[0]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:cluster_2"),
+				Ctx: SetPickedCluster(ctx, "cds:cluster_2"),
 			},
 			wantSC: m2[wantAddrs[1]],
 		},
 		{
 			pickInfo: balancer.PickInfo{
-				Ctx: SetPickedCluster(context.Background(), "cds:notacluster"),
+				Ctx: SetPickedCluster(ctx, "cds:notacluster"),
 			},
 			wantErr: status.Errorf(codes.Unavailable, `unknown cluster selected for RPC: "cds:notacluster"`),
 		},
@@ -635,8 +641,10 @@ func TestClusterGracefulSwitch(t *testing.T) {
 	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
 	sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
 	p1 := <-cc.NewPickerCh
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	pi := balancer.PickInfo{
-		Ctx: SetPickedCluster(context.Background(), "csp:cluster"),
+		Ctx: SetPickedCluster(ctx, "csp:cluster"),
 	}
 	testPick(t, p1, pi, sc1, nil)
@@ -676,8 +684,6 @@ func TestClusterGracefulSwitch(t *testing.T) {
 	// the pick first balancer to UpdateState() with CONNECTING, which shouldn't send
 	// a Picker update back, as the Graceful Switch process is not complete.
 	sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
-	ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
-	defer cancel()
 	select {
 	case <-cc.NewPickerCh:
 		t.Fatalf("No new picker should have been sent due to the Graceful Switch process not completing")

View File

@@ -105,11 +105,13 @@ func (s) TestPickerPickFirstTwo(t *testing.T) {
 			wantSCToConnect: testSubConns[1],
 		},
 	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			p := newPicker(tt.ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
 			got, err := p.Pick(balancer.PickInfo{
-				Ctx: SetRequestHash(context.Background(), tt.hash),
+				Ctx: SetRequestHash(ctx, tt.hash),
 			})
 			if err != tt.wantErr {
 				t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr)
@@ -138,7 +140,9 @@ func (s) TestPickerPickTriggerTFConnect(t *testing.T) {
 		connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure,
 	})
 	p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
-	_, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	_, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, 5)})
 	if err == nil {
 		t.Fatalf("Pick() error = %v, want non-nil", err)
 	}
@@ -168,7 +172,9 @@ func (s) TestPickerPickTriggerTFReturnReady(t *testing.T) {
 		connectivity.TransientFailure, connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Ready,
 	})
 	p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
-	pr, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	pr, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, 5)})
 	if err != nil {
 		t.Fatalf("Pick() error = %v, want nil", err)
 	}
@@ -194,7 +200,9 @@ func (s) TestPickerPickTriggerTFWithIdle(t *testing.T) {
 		connectivity.TransientFailure, connectivity.TransientFailure, connectivity.Idle, connectivity.TransientFailure, connectivity.TransientFailure,
 	})
 	p := newPicker(ring, igrpclog.NewPrefixLogger(grpclog.Component("xds"), "rh_test"))
-	_, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(context.Background(), 5)})
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	_, err := p.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, 5)})
 	if err == balancer.ErrNoSubConnAvailable {
 		t.Fatalf("Pick() error = %v, want %v", err, balancer.ErrNoSubConnAvailable)
 	}

View File

@@ -62,10 +62,6 @@ func init() {
 	}
 }

-func ctxWithHash(h uint64) context.Context {
-	return SetRequestHash(context.Background(), h)
-}
-
 // setupTest creates the balancer, and does an initial sanity check.
 func setupTest(t *testing.T, addrs []resolver.Address) (*testutils.BalancerClientConn, balancer.Balancer, balancer.Picker) {
 	t.Helper()
@@ -153,7 +149,10 @@ func (s) TestOneSubConn(t *testing.T) {
 	testHash := firstHash - 1
 	// The first pick should be queued, and should trigger Connect() on the only
 	// SubConn.
-	if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
 		t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
 	}
 	sc0 := ring0.items[0].sc.sc.(*testutils.TestSubConn)
@@ -170,7 +169,7 @@ func (s) TestOneSubConn(t *testing.T) {
 	// Test pick with one backend.
 	p1 := <-cc.NewPickerCh
 	for i := 0; i < 5; i++ {
-		gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)})
+		gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
 		if gotSCSt.SubConn != sc0 {
 			t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
 		}
@@ -196,7 +195,9 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
 	testHash := firstHash + 1
 	// The first pick should be queued, and should trigger Connect() on the only
 	// SubConn.
-	if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
 		t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
 	}
 	// The picked SubConn should be the second in the ring.
@@ -212,7 +213,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
 	sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
 	p1 := <-cc.NewPickerCh
 	for i := 0; i < 5; i++ {
-		gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)})
+		gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
 		if gotSCSt.SubConn != sc0 {
 			t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
 		}
@@ -223,7 +224,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
 	p2 := <-cc.NewPickerCh
 	// Pick with the same hash should be queued, because the SubConn after the
 	// first picked is Idle.
-	if _, err := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable {
+	if _, err := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
 		t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
 	}
@@ -241,7 +242,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
 	// New picks should all return this SubConn.
 	p3 := <-cc.NewPickerCh
 	for i := 0; i < 5; i++ {
-		gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)})
+		gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
 		if gotSCSt.SubConn != sc1 {
 			t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1)
 		}
@@ -263,7 +264,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
 	sc0.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
 	p4 := <-cc.NewPickerCh
 	for i := 0; i < 5; i++ {
-		gotSCSt, _ := p4.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)})
+		gotSCSt, _ := p4.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
 		if gotSCSt.SubConn != sc0 {
 			t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
 		}
@@ -289,7 +290,10 @@ func (s) TestThreeSubConnsAffinityMultiple(t *testing.T) {
 	testHash := firstHash + 1
 	// The first pick should be queued, and should trigger Connect() on the only
 	// SubConn.
-	if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)}); err != balancer.ErrNoSubConnAvailable {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
 		t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
 	}
 	sc0 := ring0.items[1].sc.sc.(*testutils.TestSubConn)
@@ -306,7 +310,7 @@ func (s) TestThreeSubConnsAffinityMultiple(t *testing.T) {
 	// First hash should always pick sc0.
 	p1 := <-cc.NewPickerCh
 	for i := 0; i < 5; i++ {
-		gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)})
+		gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
 		if gotSCSt.SubConn != sc0 {
 			t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
 		}
@@ -315,7 +319,7 @@ func (s) TestThreeSubConnsAffinityMultiple(t *testing.T) {
 	secondHash := ring0.items[1].hash
 	// secondHash+1 will pick the third SubConn from the ring.
 	testHash2 := secondHash + 1
-	if _, err := p0.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash2)}); err != balancer.ErrNoSubConnAvailable {
+	if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash2)}); err != balancer.ErrNoSubConnAvailable {
 		t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
 	}
 	sc1 := ring0.items[2].sc.sc.(*testutils.TestSubConn)
@@ -330,14 +334,14 @@ func (s) TestThreeSubConnsAffinityMultiple(t *testing.T) {
 	// With the new generated picker, hash2 always picks sc1.
 	p2 := <-cc.NewPickerCh
 	for i := 0; i < 5; i++ {
-		gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash2)})
+		gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash2)})
 		if gotSCSt.SubConn != sc1 {
 			t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1)
 		}
 	}
 	// But the first hash still picks sc0.
 	for i := 0; i < 5; i++ {
-		gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: ctxWithHash(testHash)})
+		gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
 		if gotSCSt.SubConn != sc0 {
 			t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
 		}
@@ -454,7 +458,9 @@ func (s) TestSubConnToConnectWhenOverallTransientFailure(t *testing.T) {
 	// ringhash won't tell SCs to connect until there is an RPC, so simulate
 	// one now.
-	p0.Pick(balancer.PickInfo{Ctx: context.Background()})
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	p0.Pick(balancer.PickInfo{Ctx: ctx})

 	// Turn the first subconn to transient failure.
 	sc0 := ring0.items[0].sc.sc.(*testutils.TestSubConn)

View File

@@ -22,6 +22,7 @@ import (
 	"context"
 	"regexp"
 	"testing"
+	"time"

 	xxhash "github.com/cespare/xxhash/v2"
 	"github.com/google/go-cmp/cmp"
@@ -34,6 +35,8 @@ import (
 	"google.golang.org/grpc/xds/internal/xdsclient/xdsresource"
 )

+var defaultTestTimeout = 10 * time.Second
+
 type s struct {
 	grpctest.Tester
 }
@@ -67,6 +70,8 @@ func (s) TestGenerateRequestHash(t *testing.T) {
 			channelID: channelID,
 		},
 	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	tests := []struct {
 		name         string
 		hashPolicies []*xdsresource.HashPolicy
@@ -85,7 +90,7 @@ func (s) TestGenerateRequestHash(t *testing.T) {
 			}},
 			requestHashWant: xxhash.Sum64String("/new-products"),
 			rpcInfo: iresolver.RPCInfo{
-				Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(":path", "/products")),
+				Context: metadata.NewOutgoingContext(ctx, metadata.Pairs(":path", "/products")),
 				Method:  "/some-method",
 			},
 		},
@@ -113,7 +118,7 @@ func (s) TestGenerateRequestHash(t *testing.T) {
 			}},
 			requestHashWant: xxhash.Sum64String("eaebece"),
 			rpcInfo: iresolver.RPCInfo{
-				Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(":path", "abc")),
+				Context: metadata.NewOutgoingContext(ctx, metadata.Pairs(":path", "abc")),
 				Method:  "/some-method",
 			},
 		},
@@ -128,7 +133,7 @@ func (s) TestGenerateRequestHash(t *testing.T) {
 			}},
 			requestHashWant: channelID,
 			rpcInfo: iresolver.RPCInfo{
-				Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something-bin", "xyz")),
+				Context: metadata.NewOutgoingContext(ctx, metadata.Pairs("something-bin", "xyz")),
 			},
 		},
 		// Tests that extra metadata takes precedence over the user's metadata.
@@ -141,7 +146,7 @@ func (s) TestGenerateRequestHash(t *testing.T) {
 			requestHashWant: xxhash.Sum64String("grpc value"),
 			rpcInfo: iresolver.RPCInfo{
 				Context: grpcutil.WithExtraMetadata(
-					metadata.NewOutgoingContext(context.Background(), metadata.Pairs("content-type", "user value")),
+					metadata.NewOutgoingContext(ctx, metadata.Pairs("content-type", "user value")),
 					metadata.Pairs("content-type", "grpc value"),
 				),
 			},

View File

@@ -1436,7 +1436,7 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) {
 	}
 	var doneFunc func()
-	_, err = res.Interceptor.NewStream(context.Background(), iresolver.RPCInfo{}, func() {}, func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
+	_, err = res.Interceptor.NewStream(ctx, iresolver.RPCInfo{}, func() {}, func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
 		doneFunc = done
 		return nil, nil
 	})

View File

@@ -25,6 +25,7 @@ import (
 	"net/netip"
 	"strings"
 	"testing"
+	"time"

 	v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
 	v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
@@ -53,6 +54,8 @@ const (
 	rLevel = "route level"
 )

+var defaultTestTimeout = 10 * time.Second
+
 func emptyValidNetworkFilters(t *testing.T) []*v3listenerpb.Filter {
 	return []*v3listenerpb.Filter{
 		{
@@ -2912,6 +2915,8 @@ func (s) TestHTTPFilterInstantiation(t *testing.T) {
 			wantErrs: []string{topLevel, vhLevel, rLevel, rLevel, rLevel, vhLevel},
 		},
 	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			fc := FilterChain{
@@ -2927,7 +2932,7 @@ func (s) TestHTTPFilterInstantiation(t *testing.T) {
 			for _, vh := range urc.VHS {
 				for _, r := range vh.Routes {
 					for _, int := range r.Interceptors {
-						errs = append(errs, int.AllowRPC(context.Background()).Error())
+						errs = append(errs, int.AllowRPC(ctx).Error())
 					}
 				}
 			}

View File

@@ -31,6 +31,8 @@ import (
 )

 func (s) TestAndMatcherMatch(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
 	tests := []struct {
 		name string
 		pm   pathMatcher
@@ -44,7 +46,7 @@ func (s) TestAndMatcherMatch(t *testing.T) {
 			hm: matcher.NewHeaderExactMatcher("th", "tv", false),
 			info: iresolver.RPCInfo{
 				Method:  "/a/b",
-				Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")),
+				Context: metadata.NewOutgoingContext(ctx, metadata.Pairs("th", "tv")),
 			},
 			want: true,
 		},
@@ -54,7 +56,7 @@ func (s) TestAndMatcherMatch(t *testing.T) {
 			hm: matcher.NewHeaderExactMatcher("th", "tv", false),
 			info: iresolver.RPCInfo{
 				Method:  "/a/b",
-				Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")),
+				Context: metadata.NewOutgoingContext(ctx, metadata.Pairs("th", "tv")),
 			},
 			want: true,
 		},
@@ -64,7 +66,7 @@ func (s) TestAndMatcherMatch(t *testing.T) {
 			hm: matcher.NewHeaderExactMatcher("th", "tv", false),
 			info: iresolver.RPCInfo{
 				Method:  "/z/y",
-				Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")),
+				Context: metadata.NewOutgoingContext(ctx, metadata.Pairs("th", "tv")),
 			},
 			want: false,
 		},
@@ -74,7 +76,7 @@ func (s) TestAndMatcherMatch(t *testing.T) {
 			hm: matcher.NewHeaderExactMatcher("th", "abc", false),
 			info: iresolver.RPCInfo{
 				Method:  "/a/b",
-				Context: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("th", "tv")),
+				Context: metadata.NewOutgoingContext(ctx, metadata.Pairs("th", "tv")),
 			},
 			want: false,
 		},
@@ -84,7 +86,7 @@ func (s) TestAndMatcherMatch(t *testing.T) {
 			hm: matcher.NewHeaderExactMatcher("content-type", "fake", false),
 			info: iresolver.RPCInfo{
 				Method: "/a/b",
-				Context: grpcutil.WithExtraMetadata(context.Background(), metadata.Pairs(
+				Context: grpcutil.WithExtraMetadata(ctx, metadata.Pairs(
 					"content-type", "fake",
 				)),
 			},
@@ -97,7 +99,7 @@ func (s) TestAndMatcherMatch(t *testing.T) {
 			info: iresolver.RPCInfo{
 				Method: "/a/b",
 				Context: grpcutil.WithExtraMetadata(
-					metadata.NewOutgoingContext(context.Background(), metadata.Pairs("t-bin", "123")), metadata.Pairs(
+					metadata.NewOutgoingContext(ctx, metadata.Pairs("t-bin", "123")), metadata.Pairs(
 						"content-type", "fake",
 					)),
 			},