clusterresolver: Avoid blocking for subsequent resolver updates in test (#7937)

This commit is contained in:
Arjan Singh Bal 2025-01-10 09:44:27 +05:30 committed by GitHub
parent 9223fd6115
commit 62b4867888
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 10 additions and 14 deletions

View File

@ -20,8 +20,8 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -1160,17 +1160,17 @@ func (s) TestEDS_EndpointWithMultipleAddresses(t *testing.T) {
return &testpb.SimpleResponse{}, nil return &testpb.SimpleResponse{}, nil
}, },
} }
lis1, err := net.Listen("tcp", "localhost:0") lis1, err := testutils.LocalTCPListener()
if err != nil { if err != nil {
t.Fatalf("Failed to create listener: %v", err) t.Fatalf("Failed to create listener: %v", err)
} }
defer lis1.Close() defer lis1.Close()
lis2, err := net.Listen("tcp", "localhost:0") lis2, err := testutils.LocalTCPListener()
if err != nil { if err != nil {
t.Fatalf("Failed to create listener: %v", err) t.Fatalf("Failed to create listener: %v", err)
} }
defer lis2.Close() defer lis2.Close()
lis3, err := net.Listen("tcp", "localhost:0") lis3, err := testutils.LocalTCPListener()
if err != nil { if err != nil {
t.Fatalf("Failed to create listener: %v", err) t.Fatalf("Failed to create listener: %v", err)
} }
@ -1223,7 +1223,8 @@ func (s) TestEDS_EndpointWithMultipleAddresses(t *testing.T) {
defer func() { defer func() {
balancer.Register(originalRRBuilder) balancer.Register(originalRRBuilder)
}() }()
resolverUpdateCh := make(chan resolver.State, 1) resolverState := atomic.Pointer[resolver.State]{}
resolverState.Store(&resolver.State{})
stub.Register(roundrobin.Name, stub.BalancerFuncs{ stub.Register(roundrobin.Name, stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) { Init: func(bd *stub.BalancerData) {
bd.Data = originalRRBuilder.Build(bd.ClientConn, bd.BuildOptions) bd.Data = originalRRBuilder.Build(bd.ClientConn, bd.BuildOptions)
@ -1232,7 +1233,7 @@ func (s) TestEDS_EndpointWithMultipleAddresses(t *testing.T) {
bd.Data.(balancer.Balancer).Close() bd.Data.(balancer.Balancer).Close()
}, },
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
resolverUpdateCh <- ccs.ResolverState resolverState.Store(&ccs.ResolverState)
return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs) return bd.Data.(balancer.Balancer).UpdateClientConnState(ccs)
}, },
}) })
@ -1297,15 +1298,10 @@ func (s) TestEDS_EndpointWithMultipleAddresses(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
var rs resolver.State gotState := resolverState.Load()
select {
case rs = <-resolverUpdateCh:
case <-ctx.Done():
t.Fatalf("Context timed out waiting for resolver update.")
}
gotEndpointPorts := []uint32{} gotEndpointPorts := []uint32{}
for _, a := range rs.Endpoints[0].Addresses { for _, a := range gotState.Endpoints[0].Addresses {
gotEndpointPorts = append(gotEndpointPorts, testutils.ParsePort(t, a.Addr)) gotEndpointPorts = append(gotEndpointPorts, testutils.ParsePort(t, a.Addr))
} }
if diff := cmp.Diff(gotEndpointPorts, tc.wantEndpointPorts); diff != "" { if diff := cmp.Diff(gotEndpointPorts, tc.wantEndpointPorts); diff != "" {
@ -1313,7 +1309,7 @@ func (s) TestEDS_EndpointWithMultipleAddresses(t *testing.T) {
} }
gotAddrPorts := []uint32{} gotAddrPorts := []uint32{}
for _, a := range rs.Addresses { for _, a := range gotState.Addresses {
gotAddrPorts = append(gotAddrPorts, testutils.ParsePort(t, a.Addr)) gotAddrPorts = append(gotAddrPorts, testutils.ParsePort(t, a.Addr))
} }
if diff := cmp.Diff(gotAddrPorts, tc.wantAddrPorts); diff != "" { if diff := cmp.Diff(gotAddrPorts, tc.wantAddrPorts); diff != "" {