pickfirst: New pick first policy for dualstack (#7498)

Arjan Singh Bal 2024-10-10 09:33:47 +05:30 committed by GitHub
parent 18a4eacc06
commit 00b9e140ce
17 changed files with 2048 additions and 11 deletions
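For context, the new policy registers under the name "pick_first_leaf" unless GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST=true is set, in which case it takes over the default "pick_first" name (see the builder changes below). A minimal, hypothetical client sketch that opts into the new policy explicitly by name via service config is shown here; the target URI and insecure credentials are placeholders, and it assumes the environment variable is left unset so the "pick_first_leaf" name is available:

package main

import (
    "log"

    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"

    // Importing the package registers the "pick_first_leaf" balancer.
    _ "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
)

func main() {
    // Select the new leaf policy explicitly and enable the optional
    // shuffling controlled by its shuffleAddressList field.
    sc := `{"loadBalancingConfig": [{"pick_first_leaf": {"shuffleAddressList": true}}]}`
    cc, err := grpc.NewClient(
        "dns:///example.test:50051", // Placeholder target.
        grpc.WithTransportCredentials(insecure.NewCredentials()),
        grpc.WithDefaultServiceConfig(sc),
    )
    if err != nil {
        log.Fatalf("grpc.NewClient() failed: %v", err)
    }
    defer cc.Close()
    // Create stubs and issue RPCs on cc as usual.
}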

View File

@@ -19,6 +19,9 @@ jobs:
- name: Run coverage
run: go test -coverprofile=coverage.out -coverpkg=./... ./...
- name: Run coverage with new pickfirst
run: GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST=true go test -coverprofile=coverage_new_pickfirst.out -coverpkg=./... ./...
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
with:

View File

@@ -70,6 +70,11 @@ jobs:
- type: tests
goversion: '1.21'
- type: tests
goversion: '1.22'
testflags: -race
grpcenv: 'GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST=true'
steps:
# Setup the environment.
- name: Setup GOARCH

View File

@@ -29,13 +29,19 @@ import (
"google.golang.org/grpc/balancer/pickfirst/internal"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
_ "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" // For automatically registering the new pickfirst if required.
)
func init() {
if envconfig.NewPickFirstEnabled {
return
}
balancer.Register(pickfirstBuilder{})
}

View File

@@ -0,0 +1,132 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pickfirst
import (
"context"
"errors"
"fmt"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
)
const (
// Default timeout for tests in this package.
defaultTestTimeout = 10 * time.Second
// Default short timeout, to be used when waiting for events which are not
// expected to happen.
defaultTestShortTimeout = 100 * time.Millisecond
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// TestPickFirstLeaf_InitialResolverError sends a resolver error to the balancer
// before a valid resolver update. It verifies that the ClientConn state is
// updated to TRANSIENT_FAILURE.
func (s) TestPickFirstLeaf_InitialResolverError(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cc := testutils.NewBalancerClientConn(t)
bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{})
defer bal.Close()
bal.ResolverError(errors.New("resolution failed: test error"))
if err := cc.WaitForConnectivityState(ctx, connectivity.TransientFailure); err != nil {
t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.TransientFailure, err)
}
// After sending a valid update, the LB policy should report CONNECTING.
ccState := balancer.ClientConnState{
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}},
{Addresses: []resolver.Address{{Addr: "2.2.2.2:2"}}},
},
},
}
if err := bal.UpdateClientConnState(ccState); err != nil {
t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err)
}
if err := cc.WaitForConnectivityState(ctx, connectivity.Connecting); err != nil {
t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.Connecting, err)
}
}
// TestPickFirstLeaf_ResolverErrorinTF sends a resolver error to the balancer
// while the SubConn it is attempting to connect to is in TRANSIENT_FAILURE. It
// verifies that the picker is updated and the SubConn is not closed.
func (s) TestPickFirstLeaf_ResolverErrorinTF(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cc := testutils.NewBalancerClientConn(t)
bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{})
defer bal.Close()
// After sending a valid update, the LB policy should report CONNECTING.
ccState := balancer.ClientConnState{
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}},
},
},
}
if err := bal.UpdateClientConnState(ccState); err != nil {
t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err)
}
sc1 := <-cc.NewSubConnCh
if err := cc.WaitForConnectivityState(ctx, connectivity.Connecting); err != nil {
t.Fatalf("cc.WaitForConnectivityState(%v) returned error: %v", connectivity.Connecting, err)
}
scErr := fmt.Errorf("test error: connection refused")
sc1.UpdateState(balancer.SubConnState{
ConnectivityState: connectivity.TransientFailure,
ConnectionError: scErr,
})
if err := cc.WaitForPickerWithErr(ctx, scErr); err != nil {
t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", scErr, err)
}
bal.ResolverError(errors.New("resolution failed: test error"))
if err := cc.WaitForErrPicker(ctx); err != nil {
t.Fatalf("cc.WaitForErrPicker() returned error: %v", err)
}
select {
case <-time.After(defaultTestShortTimeout):
case sc := <-cc.ShutdownSubConnCh:
t.Fatalf("Unexpected SubConn shutdown: %v", sc)
}
}

View File

@@ -0,0 +1,624 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package pickfirstleaf contains the pick_first load balancing policy which
// will be the universal leaf policy after dualstack changes are implemented.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package pickfirstleaf
import (
"encoding/json"
"errors"
"fmt"
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst/internal"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/envconfig"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
func init() {
if envconfig.NewPickFirstEnabled {
// Register as the default pick_first balancer.
Name = "pick_first"
}
balancer.Register(pickfirstBuilder{})
}
var (
logger = grpclog.Component("pick-first-leaf-lb")
// Name is the name of the pick_first_leaf balancer.
// It is changed to "pick_first" in init() if this balancer is to be
// registered as the default pickfirst.
Name = "pick_first_leaf"
)
// TODO: change to pick-first when this becomes the default pick_first policy.
const logPrefix = "[pick-first-leaf-lb %p] "
type pickfirstBuilder struct{}
func (pickfirstBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer {
b := &pickfirstBalancer{
cc: cc,
addressList: addressList{},
subConns: resolver.NewAddressMap(),
state: connectivity.Connecting,
mu: sync.Mutex{},
}
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b))
return b
}
func (b pickfirstBuilder) Name() string {
return Name
}
func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
var cfg pfConfig
if err := json.Unmarshal(js, &cfg); err != nil {
return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err)
}
return cfg, nil
}
type pfConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
// If set to true, instructs the LB policy to shuffle the order of the list
// of endpoints received from the name resolver before attempting to
// connect to them.
ShuffleAddressList bool `json:"shuffleAddressList"`
}
// scData keeps track of the current state of the subConn.
// It is not safe for concurrent access.
type scData struct {
// The following fields are initialized at build time and read-only after
// that.
subConn balancer.SubConn
addr resolver.Address
state connectivity.State
lastErr error
}
func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) {
sd := &scData{
state: connectivity.Idle,
addr: addr,
}
sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{
StateListener: func(state balancer.SubConnState) {
b.updateSubConnState(sd, state)
},
})
if err != nil {
return nil, err
}
sd.subConn = sc
return sd, nil
}
type pickfirstBalancer struct {
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.
logger *internalgrpclog.PrefixLogger
cc balancer.ClientConn
// The mutex is used to synchronize updates triggered from the idle
// picker with the already serialized resolver and SubConn state
// updates.
mu sync.Mutex
state connectivity.State
// scData for active subconns mapped by address.
subConns *resolver.AddressMap
addressList addressList
firstPass bool
numTF int
}
// ResolverError is called by the ClientConn when the name resolver produces
// an error or when pickfirst determines the resolver update to be invalid.
func (b *pickfirstBalancer) ResolverError(err error) {
b.mu.Lock()
defer b.mu.Unlock()
b.resolverErrorLocked(err)
}
func (b *pickfirstBalancer) resolverErrorLocked(err error) {
if b.logger.V(2) {
b.logger.Infof("Received error from the name resolver: %v", err)
}
// The picker will not change since the balancer does not currently
// report an error. If the balancer hasn't received a single good resolver
// update yet, transition to TRANSIENT_FAILURE.
if b.state != connectivity.TransientFailure && b.addressList.size() > 0 {
if b.logger.V(2) {
b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.")
}
return
}
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)},
})
}
func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
b.mu.Lock()
defer b.mu.Unlock()
if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
// Cleanup state pertaining to the previous resolver state.
// Treat an empty address list like an error by calling b.ResolverError.
b.state = connectivity.TransientFailure
b.closeSubConnsLocked()
b.addressList.updateAddrs(nil)
b.resolverErrorLocked(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState
}
cfg, ok := state.BalancerConfig.(pfConfig)
if state.BalancerConfig != nil && !ok {
return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState)
}
if b.logger.V(2) {
b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
}
var newAddrs []resolver.Address
if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
// Perform the optional shuffling described in gRFC A62. The shuffling
// will change the order of endpoints but not touch the order of the
// addresses within each endpoint. - A61
if cfg.ShuffleAddressList {
endpoints = append([]resolver.Endpoint{}, endpoints...)
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
}
// "Flatten the list by concatenating the ordered list of addresses for
// each of the endpoints, in order." - A61
for _, endpoint := range endpoints {
// "In the flattened list, interleave addresses from the two address
// families, as per RFC-8305 section 4." - A61
// TODO: support the above language.
newAddrs = append(newAddrs, endpoint.Addresses...)
}
} else {
// Endpoints not set, process addresses until we migrate resolver
// emissions fully to Endpoints. The top channel does wrap emitted
// addresses with endpoints, however some balancers, such as weighted
// target, do not forward the corresponding endpoints down correctly or
// split endpoints properly. Once all balancers correctly forward
// endpoints down, this else branch can be deleted.
newAddrs = state.ResolverState.Addresses
if cfg.ShuffleAddressList {
newAddrs = append([]resolver.Address{}, newAddrs...)
internal.RandShuffle(len(newAddrs), func(i, j int) { newAddrs[i], newAddrs[j] = newAddrs[j], newAddrs[i] })
}
}
// If an address appears in multiple endpoints or in the same endpoint
// multiple times, we keep it only once. We will create only one SubConn
// for the address because an AddressMap is used to store SubConns.
// Not de-duplicating would result in attempting to connect to the same
// SubConn multiple times in the same pass. We don't want this.
newAddrs = deDupAddresses(newAddrs)
// Since we have a new set of addresses, we are again at first pass.
b.firstPass = true
// If the previous ready SubConn exists in the new address list,
// keep this connection and don't create new SubConns.
prevAddr := b.addressList.currentAddress()
prevAddrsCount := b.addressList.size()
b.addressList.updateAddrs(newAddrs)
if b.state == connectivity.Ready && b.addressList.seekTo(prevAddr) {
return nil
}
b.reconcileSubConnsLocked(newAddrs)
// If it's the first resolver update or the balancer was already READY
// (but the new address list does not contain the ready SubConn) or
// CONNECTING, enter CONNECTING.
// We may be in TRANSIENT_FAILURE due to a previous empty address list;
// we should still enter CONNECTING because the sticky TF behaviour
// mentioned in A62 applies only when the TRANSIENT_FAILURE is reported
// due to connectivity failures.
if b.state == connectivity.Ready || b.state == connectivity.Connecting || prevAddrsCount == 0 {
// Start connection attempt at first address.
b.state = connectivity.Connecting
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
b.requestConnectionLocked()
} else if b.state == connectivity.TransientFailure {
// If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until
// we're READY. See A62.
b.requestConnectionLocked()
}
return nil
}
// UpdateSubConnState is unused as a StateListener is always registered when
// creating SubConns.
func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) {
b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", subConn, state)
}
func (b *pickfirstBalancer) Close() {
b.mu.Lock()
defer b.mu.Unlock()
b.closeSubConnsLocked()
b.state = connectivity.Shutdown
}
// ExitIdle moves the balancer out of idle state. It can be called concurrently
// by the idlePicker and clientConn so access to variables should be
// synchronized.
func (b *pickfirstBalancer) ExitIdle() {
b.mu.Lock()
defer b.mu.Unlock()
if b.state == connectivity.Idle && b.addressList.currentAddress() == b.addressList.first() {
b.firstPass = true
b.requestConnectionLocked()
}
}
func (b *pickfirstBalancer) closeSubConnsLocked() {
for _, sd := range b.subConns.Values() {
sd.(*scData).subConn.Shutdown()
}
b.subConns = resolver.NewAddressMap()
}
// deDupAddresses ensures that each address appears only once in the slice.
func deDupAddresses(addrs []resolver.Address) []resolver.Address {
seenAddrs := resolver.NewAddressMap()
retAddrs := []resolver.Address{}
for _, addr := range addrs {
if _, ok := seenAddrs.Get(addr); ok {
continue
}
// Mark the address as seen so that later duplicates are skipped.
seenAddrs.Set(addr, true)
retAddrs = append(retAddrs, addr)
}
return retAddrs
}
func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) {
// Remove old subConns that were not in the new address list.
oldAddrsMap := resolver.NewAddressMap()
for _, k := range b.subConns.Keys() {
oldAddrsMap.Set(k, true)
}
// Flatten the new endpoint addresses.
newAddrsMap := resolver.NewAddressMap()
for _, addr := range newAddrs {
newAddrsMap.Set(addr, true)
}
// Shut them down and remove them.
for _, oldAddr := range oldAddrsMap.Keys() {
if _, ok := newAddrsMap.Get(oldAddr); ok {
continue
}
val, _ := b.subConns.Get(oldAddr)
val.(*scData).subConn.Shutdown()
b.subConns.Delete(oldAddr)
}
}
// shutdownRemainingLocked shuts down remaining subConns. Called when a subConn
// becomes ready, which means that all other subConns must be shut down.
func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) {
for _, v := range b.subConns.Values() {
sd := v.(*scData)
if sd.subConn != selected.subConn {
sd.subConn.Shutdown()
}
}
b.subConns = resolver.NewAddressMap()
b.subConns.Set(selected.addr, selected)
}
// requestConnectionLocked starts connecting on the subchannel corresponding to
// the current address. If no subchannel exists, one is created. If the current
// subchannel is in TransientFailure, a connection to the next address is
// attempted, and so on, until a usable subchannel is found or the list is
// exhausted.
func (b *pickfirstBalancer) requestConnectionLocked() {
if !b.addressList.isValid() {
return
}
var lastErr error
for valid := true; valid; valid = b.addressList.increment() {
curAddr := b.addressList.currentAddress()
sd, ok := b.subConns.Get(curAddr)
if !ok {
var err error
// We want to assign the new scData to sd from the outer scope,
// hence we can't use := below.
sd, err = b.newSCData(curAddr)
if err != nil {
// This should never happen, unless the clientConn is being shut
// down.
if b.logger.V(2) {
b.logger.Infof("Failed to create a subConn for address %v: %v", curAddr.String(), err)
}
// Do nothing, the LB policy will be closed soon.
return
}
b.subConns.Set(curAddr, sd)
}
scd := sd.(*scData)
switch scd.state {
case connectivity.Idle:
scd.subConn.Connect()
case connectivity.TransientFailure:
// Try the next address.
lastErr = scd.lastErr
continue
case connectivity.Ready:
// Should never happen.
b.logger.Errorf("Requesting a connection even though we have a READY SubConn")
case connectivity.Shutdown:
// Should never happen.
b.logger.Errorf("SubConn with state SHUTDOWN present in SubConns map")
case connectivity.Connecting:
// Wait for the SubConn to report success or failure.
}
return
}
// All the remaining addresses in the list are in TRANSIENT_FAILURE, end the
// first pass.
b.endFirstPassLocked(lastErr)
}
func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) {
b.mu.Lock()
defer b.mu.Unlock()
oldState := sd.state
sd.state = newState.ConnectivityState
// Previously relevant SubConns can still call back with state updates.
// To prevent pickers from returning these obsolete SubConns, this check
// ensures that the current list of active SubConns includes this SubConn.
if activeSD, found := b.subConns.Get(sd.addr); !found || activeSD != sd {
return
}
if newState.ConnectivityState == connectivity.Shutdown {
return
}
if newState.ConnectivityState == connectivity.Ready {
b.shutdownRemainingLocked(sd)
if !b.addressList.seekTo(sd.addr) {
// This should not fail as we should have only one SubConn after
// entering READY. The SubConn should be present in the addressList.
b.logger.Errorf("Address %q not found in address list %v", sd.addr, b.addressList.addresses)
return
}
b.state = connectivity.Ready
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Ready,
Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}},
})
return
}
// If the LB policy is READY, and it receives a subchannel state change,
// it means that the READY subchannel has failed.
// A SubConn can also transition from CONNECTING directly to IDLE when
// a transport is successfully created, but the connection fails
// before the SubConn can send the notification for READY. We treat
// this as a successful connection and transition to IDLE.
if (b.state == connectivity.Ready && newState.ConnectivityState != connectivity.Ready) || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) {
// Once a transport fails, the balancer enters IDLE and starts from
// the first address when the picker is used.
b.shutdownRemainingLocked(sd)
b.state = connectivity.Idle
b.addressList.reset()
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Idle,
Picker: &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)},
})
return
}
if b.firstPass {
switch newState.ConnectivityState {
case connectivity.Connecting:
// The balancer can be in either IDLE, CONNECTING or
// TRANSIENT_FAILURE. If it's in TRANSIENT_FAILURE, stay in
// TRANSIENT_FAILURE until it's READY. See A62.
// If the balancer is already in CONNECTING, no update is needed.
if b.state == connectivity.Idle {
b.state = connectivity.Connecting
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Connecting,
Picker: &picker{err: balancer.ErrNoSubConnAvailable},
})
}
case connectivity.TransientFailure:
sd.lastErr = newState.ConnectionError
// Since we're re-using common SubConns while handling resolver
// updates, we could receive an out of turn TRANSIENT_FAILURE from
// a pass over the previous address list. We ignore such updates.
if curAddr := b.addressList.currentAddress(); !equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) {
return
}
if b.addressList.increment() {
b.requestConnectionLocked()
return
}
// End of the first pass.
b.endFirstPassLocked(newState.ConnectionError)
}
return
}
// We have finished the first pass; keep re-connecting failing SubConns.
switch newState.ConnectivityState {
case connectivity.TransientFailure:
b.numTF = (b.numTF + 1) % b.subConns.Len()
sd.lastErr = newState.ConnectionError
if b.numTF%b.subConns.Len() == 0 {
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: newState.ConnectionError},
})
}
// We don't need to request re-resolution since the SubConn already
// does that before reporting TRANSIENT_FAILURE.
// TODO: #7534 - Move re-resolution requests from SubConn into
// pick_first.
case connectivity.Idle:
sd.subConn.Connect()
}
}
func (b *pickfirstBalancer) endFirstPassLocked(lastErr error) {
b.firstPass = false
b.numTF = 0
b.state = connectivity.TransientFailure
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: lastErr},
})
// Start re-connecting all the SubConns that are already in IDLE.
for _, v := range b.subConns.Values() {
sd := v.(*scData)
if sd.state == connectivity.Idle {
sd.subConn.Connect()
}
}
}
type picker struct {
result balancer.PickResult
err error
}
func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
return p.result, p.err
}
// idlePicker is used when the SubConn is IDLE and kicks the SubConn into
// CONNECTING when Pick is called.
type idlePicker struct {
exitIdle func()
}
func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
i.exitIdle()
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
// addressList manages sequentially iterating over addresses present in a list
// of endpoints. It provides a 1 dimensional view of the addresses present in
// the endpoints.
// This type is not safe for concurrent access.
type addressList struct {
addresses []resolver.Address
idx int
}
func (al *addressList) isValid() bool {
return al.idx < len(al.addresses)
}
func (al *addressList) size() int {
return len(al.addresses)
}
// increment moves to the next index in the address list.
// This method returns false if it went off the list, true otherwise.
func (al *addressList) increment() bool {
if !al.isValid() {
return false
}
al.idx++
return al.idx < len(al.addresses)
}
// currentAddress returns the current address pointed to in the addressList.
// If the list is in an invalid state, it returns an empty address instead.
func (al *addressList) currentAddress() resolver.Address {
if !al.isValid() {
return resolver.Address{}
}
return al.addresses[al.idx]
}
// first returns the first address in the list. If the list is empty, it returns
// an empty address instead.
func (al *addressList) first() resolver.Address {
if len(al.addresses) == 0 {
return resolver.Address{}
}
return al.addresses[0]
}
func (al *addressList) reset() {
al.idx = 0
}
func (al *addressList) updateAddrs(addrs []resolver.Address) {
al.addresses = addrs
al.reset()
}
// seekTo returns false if the needle was not found and the current index was
// left unchanged.
func (al *addressList) seekTo(needle resolver.Address) bool {
for ai, addr := range al.addresses {
if !equalAddressIgnoringBalAttributes(&addr, &needle) {
continue
}
al.idx = ai
return true
}
return false
}
// equalAddressIgnoringBalAttributes returns true if a and b are considered
// equal. This is different from the Equal method on the resolver.Address type
// which considers all fields to determine equality. Here, we only consider
// fields that are meaningful to the SubConn.
func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
return a.Addr == b.Addr && a.ServerName == b.ServerName &&
a.Attributes.Equal(b.Attributes) &&
a.Metadata == b.Metadata
}

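The UpdateClientConnState hunk above leaves the RFC 8305 section 4 interleaving of address families as a TODO. As a rough, standalone illustration only (hypothetical package and function names, not part of this commit), such interleaving could look like this:

package rfc8305sketch

import (
    "net"

    "google.golang.org/grpc/resolver"
)

// interleaveAddressFamilies alternates between address families, starting with
// the family of the first address and preserving the relative order within
// each family, roughly as described in RFC 8305 section 4.
func interleaveAddressFamilies(addrs []resolver.Address) []resolver.Address {
    isV6 := func(a resolver.Address) bool {
        host, _, err := net.SplitHostPort(a.Addr)
        if err != nil {
            host = a.Addr
        }
        ip := net.ParseIP(host)
        return ip != nil && ip.To4() == nil
    }
    var v4, v6 []resolver.Address
    for _, a := range addrs {
        if isV6(a) {
            v6 = append(v6, a)
        } else {
            v4 = append(v4, a)
        }
    }
    // Start with the family of the first address in the flattened list.
    first, second := v4, v6
    if len(addrs) > 0 && isV6(addrs[0]) {
        first, second = v6, v4
    }
    out := make([]resolver.Address, 0, len(addrs))
    for i := 0; i < len(first) || i < len(second); i++ {
        if i < len(first) {
            out = append(out, first[i])
        }
        if i < len(second) {
            out = append(out, second[i])
        }
    }
    return out
}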
View File

@@ -0,0 +1,957 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pickfirstleaf_test
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/testutils/pickfirst"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)
const (
// Default timeout for tests in this package.
defaultTestTimeout = 10 * time.Second
// Default short timeout, to be used when waiting for events which are not
// expected to happen.
defaultTestShortTimeout = 100 * time.Millisecond
stateStoringBalancerName = "state_storing"
)
var stateStoringServiceConfig = fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateStoringBalancerName)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// setupPickFirstLeaf performs steps required for pick_first tests. It starts a
// number of backends exporting the TestService and creates a ClientConn to them
// with a service config specifying the use of the state_storing LB policy.
func setupPickFirstLeaf(t *testing.T, backendCount int, opts ...grpc.DialOption) (*grpc.ClientConn, *manual.Resolver, *backendManager) {
t.Helper()
r := manual.NewBuilderWithScheme("whatever")
backends := make([]*stubserver.StubServer, backendCount)
addrs := make([]resolver.Address, backendCount)
for i := 0; i < backendCount; i++ {
backend := stubserver.StartTestService(t, nil)
t.Cleanup(func() {
backend.Stop()
})
backends[i] = backend
addrs[i] = resolver.Address{Addr: backend.Address}
}
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(stateStoringServiceConfig),
}
dopts = append(dopts, opts...)
cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("grpc.NewClient() failed: %v", err)
}
t.Cleanup(func() { cc.Close() })
// At this point, the resolver has not returned any addresses to the channel.
// This RPC must block until the context expires.
sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer sCancel()
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(sCtx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() = %s, want %s", status.Code(err), codes.DeadlineExceeded)
}
return cc, r, &backendManager{backends}
}
// TestPickFirstLeaf_SimpleResolverUpdate tests the behaviour of the pick first
// policy when given a list of addresses. The following steps are carried
// out in order:
// 1. A list of addresses is given through the resolver. Only one
// of the servers is running.
// 2. RPCs are sent to verify they reach the running server.
//
// The state transitions of the ClientConn and all the subconns created are
// verified.
func (s) TestPickFirstLeaf_SimpleResolverUpdate_FirstServerReady(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 2)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
r.UpdateState(resolver.State{Addresses: addrs})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
func (s) TestPickFirstLeaf_SimpleResolverUpdate_FirstServerUnReady(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 2)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
bm.stopAllExcept(1)
r.UpdateState(resolver.State{Addresses: addrs})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
func (s) TestPickFirstLeaf_SimpleResolverUpdate_DuplicateAddrs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 2)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
bm.stopAllExcept(1)
// Add a duplicate entry in the address list.
r.UpdateState(resolver.State{
Addresses: append([]resolver.Address{addrs[0]}, addrs...),
})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
// TestPickFirstLeaf_ResolverUpdates_DisjointLists tests the behaviour of the pick first
// policy when the following steps are carried out in order:
// 1. A list of addresses is given through the resolver. Only one
// of the servers is running.
// 2. RPCs are sent to verify they reach the running server.
// 3. A second resolver update is sent. Again, only one of the servers is
// running. This may not be the same server as before.
// 4. RPCs are sent to verify they reach the running server.
//
// The state transitions of the ClientConn and all the subconns created are
// verified.
func (s) TestPickFirstLeaf_ResolverUpdates_DisjointLists(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 4)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
bm.backends[0].S.Stop()
bm.backends[0].S = nil
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
bm.backends[2].S.Stop()
bm.backends[2].S = nil
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[2], addrs[3]}})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[3]); err != nil {
t.Fatal(err)
}
wantSCStates = []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[2]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[3]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
func (s) TestPickFirstLeaf_ResolverUpdates_ActiveBackendInUpdatedList(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 3)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
bm.backends[0].S.Stop()
bm.backends[0].S = nil
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
bm.backends[2].S.Stop()
bm.backends[2].S = nil
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[2], addrs[1]}})
// Verify that the ClientConn stays in READY.
sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer sCancel()
testutils.AwaitNoStateChange(sCtx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates = []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
func (s) TestPickFirstLeaf_ResolverUpdates_InActiveBackendInUpdatedList(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 3)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
bm.backends[0].S.Stop()
bm.backends[0].S = nil
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
bm.backends[2].S.Stop()
bm.backends[2].S = nil
if err := bm.backends[0].StartServer(); err != nil {
t.Fatalf("Failed to re-start test backend: %v", err)
}
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[2]}})
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
wantSCStates = []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
func (s) TestPickFirstLeaf_ResolverUpdates_IdenticalLists(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 2)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
bm.backends[0].S.Stop()
bm.backends[0].S = nil
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}})
// Verify that the ClientConn stays in READY.
sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer sCancel()
testutils.AwaitNoStateChange(sCtx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates = []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
// TestPickFirstLeaf_StopConnectedServer tests the behaviour of the pick first
// policy when the connected server is shut down. It carries out the following
// steps in order:
// 1. A list of addresses is given through the resolver. Only one
// of the servers is running.
// 2. The running server is stopped, causing the ClientConn to enter IDLE.
// 3. A (possibly different) server is started.
// 4. RPCs are made to kick the ClientConn out of IDLE. The test verifies that
// the RPCs reach the running server.
//
// The test verifies the ClientConn state transitions.
func (s) TestPickFirstLeaf_StopConnectedServer_FirstServerRestart(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 2)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
// Shut down all active backends except the target.
bm.stopAllExcept(0)
r.UpdateState(resolver.State{Addresses: addrs})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
// Shut down the connected server.
bm.backends[0].S.Stop()
bm.backends[0].S = nil
testutils.AwaitState(ctx, t, cc, connectivity.Idle)
// Start the new target server.
if err := bm.backends[0].StartServer(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.Idle,
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
func (s) TestPickFirstLeaf_StopConnectedServer_SecondServerRestart(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 2)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
// Shut down all active backends except the target.
bm.stopAllExcept(1)
r.UpdateState(resolver.State{Addresses: addrs})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
// Shut down the connected server.
bm.backends[1].S.Stop()
bm.backends[1].S = nil
testutils.AwaitState(ctx, t, cc, connectivity.Idle)
// Start the new target server.
if err := bm.backends[1].StartServer(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates = []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.Idle,
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
func (s) TestPickFirstLeaf_StopConnectedServer_SecondServerToFirst(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 2)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
// Shut down all active backends except the target.
bm.stopAllExcept(1)
r.UpdateState(resolver.State{Addresses: addrs})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
// Shut down the connected server.
bm.backends[1].S.Stop()
bm.backends[1].S = nil
testutils.AwaitState(ctx, t, cc, connectivity.Idle)
// Start the new target server.
if err := bm.backends[0].StartServer(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
wantSCStates = []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.Idle,
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
func (s) TestPickFirstLeaf_StopConnectedServer_FirstServerToSecond(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balCh := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balCh})
cc, r, bm := setupPickFirstLeaf(t, 2)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
// Shut down all active backends except the target.
bm.stopAllExcept(0)
r.UpdateState(resolver.State{Addresses: addrs})
var bal *stateStoringBalancer
select {
case bal = <-balCh:
case <-ctx.Done():
t.Fatal("Context expired while waiting for balancer to be built")
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
wantSCStates := []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
// Shut down the connected server.
bm.backends[0].S.Stop()
bm.backends[0].S = nil
testutils.AwaitState(ctx, t, cc, connectivity.Idle)
// Start the new target server.
if err := bm.backends[1].StartServer(); err != nil {
t.Fatalf("Failed to start server: %v", err)
}
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[1]); err != nil {
t.Fatal(err)
}
wantSCStates = []scState{
{Addrs: []resolver.Address{addrs[0]}, State: connectivity.Shutdown},
{Addrs: []resolver.Address{addrs[1]}, State: connectivity.Ready},
}
if diff := cmp.Diff(wantSCStates, bal.subConnStates()); diff != "" {
t.Errorf("subconn states mismatch (-want +got):\n%s", diff)
}
wantConnStateTransitions := []connectivity.State{
connectivity.Connecting,
connectivity.Ready,
connectivity.Idle,
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantConnStateTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
// TestPickFirstLeaf_EmptyAddressList carries out the following steps in order:
// 1. Send a resolver update with one running backend.
// 2. Send an empty address list causing the balancer to enter TRANSIENT_FAILURE.
// 3. Send a resolver update with one running backend.
// The test verifies the ClientConn state transitions.
func (s) TestPickFirstLeaf_EmptyAddressList(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
balChan := make(chan *stateStoringBalancer, 1)
balancer.Register(&stateStoringBalancerBuilder{balancer: balChan})
cc, r, bm := setupPickFirstLeaf(t, 1)
addrs := bm.resolverAddrs()
stateSubscriber := &ccStateSubscriber{}
internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(cc, stateSubscriber)
r.UpdateState(resolver.State{Addresses: addrs})
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
r.UpdateState(resolver.State{})
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
// The balancer should have entered transient failure.
// It should transition to CONNECTING from TRANSIENT_FAILURE as sticky TF
// only applies when the initial TF is reported due to connection failures
// and not bad resolver states.
r.UpdateState(resolver.State{Addresses: addrs})
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
t.Fatal(err)
}
wantTransitions := []connectivity.State{
// From first resolver update.
connectivity.Connecting,
connectivity.Ready,
// From second update.
connectivity.TransientFailure,
// From third update.
connectivity.Connecting,
connectivity.Ready,
}
if diff := cmp.Diff(wantTransitions, stateSubscriber.transitions); diff != "" {
t.Errorf("ClientConn states mismatch (-want +got):\n%s", diff)
}
}
// stateStoringBalancer stores the state of the subconns being created.
type stateStoringBalancer struct {
balancer.Balancer
mu sync.Mutex
scStates []*scState
}
func (b *stateStoringBalancer) Close() {
b.Balancer.Close()
}
func (b *stateStoringBalancer) ExitIdle() {
if ib, ok := b.Balancer.(balancer.ExitIdler); ok {
ib.ExitIdle()
}
}
type stateStoringBalancerBuilder struct {
balancer chan *stateStoringBalancer
}
func (b *stateStoringBalancerBuilder) Name() string {
return stateStoringBalancerName
}
func (b *stateStoringBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
bal := &stateStoringBalancer{}
bal.Balancer = balancer.Get(pickfirstleaf.Name).Build(&stateStoringCCWrapper{cc, bal}, opts)
b.balancer <- bal
return bal
}
func (b *stateStoringBalancer) subConnStates() []scState {
b.mu.Lock()
defer b.mu.Unlock()
ret := []scState{}
for _, s := range b.scStates {
ret = append(ret, *s)
}
return ret
}
func (b *stateStoringBalancer) addSCState(state *scState) {
b.mu.Lock()
b.scStates = append(b.scStates, state)
b.mu.Unlock()
}
type stateStoringCCWrapper struct {
balancer.ClientConn
b *stateStoringBalancer
}
func (ccw *stateStoringCCWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
oldListener := opts.StateListener
scs := &scState{
State: connectivity.Idle,
Addrs: addrs,
}
ccw.b.addSCState(scs)
opts.StateListener = func(s balancer.SubConnState) {
ccw.b.mu.Lock()
scs.State = s.ConnectivityState
ccw.b.mu.Unlock()
oldListener(s)
}
return ccw.ClientConn.NewSubConn(addrs, opts)
}
type scState struct {
State connectivity.State
Addrs []resolver.Address
}
type backendManager struct {
backends []*stubserver.StubServer
}
func (b *backendManager) stopAllExcept(index int) {
for idx, b := range b.backends {
if idx != index {
b.S.Stop()
b.S = nil
}
}
}
// resolverAddrs returns a list of resolver addresses for the stub server
// backends. Useful when pushing addresses to the manual resolver.
func (b *backendManager) resolverAddrs() []resolver.Address {
addrs := make([]resolver.Address, len(b.backends))
for i, backend := range b.backends {
addrs[i] = resolver.Address{Addr: backend.Address}
}
return addrs
}
type ccStateSubscriber struct {
transitions []connectivity.State
}
func (c *ccStateSubscriber) OnMessage(msg any) {
c.transitions = append(c.transitions, msg.(connectivity.State))
}

View File

@@ -0,0 +1,259 @@
/*
*
* Copyright 2024 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pickfirstleaf
import (
"context"
"fmt"
"testing"
"time"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
)
const (
// Default timeout for tests in this package.
defaultTestTimeout = 10 * time.Second
// Default short timeout, to be used when waiting for events which are not
// expected to happen.
defaultTestShortTimeout = 100 * time.Millisecond
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// TestAddressList_Iteration verifies the behaviour of the addressList while
// iterating through the entries.
func (s) TestAddressList_Iteration(t *testing.T) {
addrs := []resolver.Address{
{
Addr: "192.168.1.1",
ServerName: "test-host-1",
Attributes: attributes.New("key-1", "val-1"),
BalancerAttributes: attributes.New("bal-key-1", "bal-val-1"),
},
{
Addr: "192.168.1.2",
ServerName: "test-host-2",
Attributes: attributes.New("key-2", "val-2"),
BalancerAttributes: attributes.New("bal-key-2", "bal-val-2"),
},
{
Addr: "192.168.1.3",
ServerName: "test-host-3",
Attributes: attributes.New("key-3", "val-3"),
BalancerAttributes: attributes.New("bal-key-3", "bal-val-3"),
},
}
addressList := addressList{}
emptyAddress := resolver.Address{}
if got, want := addressList.first(), emptyAddress; got != want {
t.Fatalf("addressList.first() = %v, want %v", got, want)
}
addressList.updateAddrs(addrs)
if got, want := addressList.first(), addressList.currentAddress(); got != want {
t.Fatalf("addressList.first() = %v, want %v", got, want)
}
if got, want := addressList.first(), addrs[0]; got != want {
t.Fatalf("addressList.first() = %v, want %v", got, want)
}
for i := 0; i < len(addrs); i++ {
if got, want := addressList.isValid(), true; got != want {
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}
if got, want := addressList.currentAddress(), addrs[i]; !want.Equal(got) {
t.Errorf("addressList.currentAddress() = %v, want %v", got, want)
}
if got, want := addressList.increment(), i+1 < len(addrs); got != want {
t.Fatalf("addressList.increment() = %t, want %t", got, want)
}
}
if got, want := addressList.isValid(), false; got != want {
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}
// Increment an invalid address list.
if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
if got, want := addressList.isValid(), false; got != want {
t.Errorf("addressList.isValid() = %t, want %t", got, want)
}
addressList.reset()
for i := 0; i < len(addrs); i++ {
if got, want := addressList.isValid(), true; got != want {
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}
if got, want := addressList.currentAddress(), addrs[i]; !want.Equal(got) {
t.Errorf("addressList.currentAddress() = %v, want %v", got, want)
}
if got, want := addressList.increment(), i+1 < len(addrs); got != want {
t.Fatalf("addressList.increment() = %t, want %t", got, want)
}
}
}
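// Editor's sketch: a minimal model of the cursor semantics exercised above,
// assuming an index-based implementation. The real addressList lives in
// pickfirstleaf.go; the names below are illustrative only.
type addressCursorSketch struct {
	addresses []resolver.Address
	idx       int
}

// isValid reports whether the cursor still points at an address in the list.
func (c *addressCursorSketch) isValid() bool { return c.idx < len(c.addresses) }

// increment advances the cursor by one entry and reports whether it remains
// valid, matching the increment() expectations in the test above.
func (c *addressCursorSketch) increment() bool {
	if !c.isValid() {
		return false
	}
	c.idx++
	return c.isValid()
}

// reset rewinds the cursor to the first address.
func (c *addressCursorSketch) reset() { c.idx = 0 }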
// TestAddressList_SeekTo verifies the behaviour of addressList.seekTo.
func (s) TestAddressList_SeekTo(t *testing.T) {
addrs := []resolver.Address{
{
Addr: "192.168.1.1",
ServerName: "test-host-1",
Attributes: attributes.New("key-1", "val-1"),
BalancerAttributes: attributes.New("bal-key-1", "bal-val-1"),
},
{
Addr: "192.168.1.2",
ServerName: "test-host-2",
Attributes: attributes.New("key-2", "val-2"),
BalancerAttributes: attributes.New("bal-key-2", "bal-val-2"),
},
{
Addr: "192.168.1.3",
ServerName: "test-host-3",
Attributes: attributes.New("key-3", "val-3"),
BalancerAttributes: attributes.New("bal-key-3", "bal-val-3"),
},
}
addressList := addressList{}
addressList.updateAddrs(addrs)
// Try finding an address in the list.
key := resolver.Address{
Addr: "192.168.1.2",
ServerName: "test-host-2",
Attributes: attributes.New("key-2", "val-2"),
BalancerAttributes: attributes.New("ignored", "bal-val-2"),
}
if got, want := addressList.seekTo(key), true; got != want {
t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
}
// It should be possible to increment once more now that the pointer has advanced.
if got, want := addressList.increment(), true; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
// Seek to the key again; it is now behind the pointer.
if got, want := addressList.seekTo(key), true; got != want {
t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
}
// Seek to a key not in the list.
key = resolver.Address{
Addr: "192.168.1.5",
ServerName: "test-host-5",
Attributes: attributes.New("key-5", "val-5"),
BalancerAttributes: attributes.New("ignored", "bal-val-5"),
}
if got, want := addressList.seekTo(key), false; got != want {
t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
}
// It should be possible to increment once more since the pointer has not advanced.
if got, want := addressList.increment(), true; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
}
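// Editor's sketch: an assumed helper inferred from the seekTo expectations
// above (not copied from the implementation). Matching considers Addr,
// ServerName and Attributes while ignoring BalancerAttributes, which is why a
// key with different balancer attributes still locates its entry.
func equalIgnoringBalAttrsSketch(a, b resolver.Address) bool {
	return a.Addr == b.Addr &&
		a.ServerName == b.ServerName &&
		a.Attributes.Equal(b.Attributes)
}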
// TestPickFirstLeaf_TFPickerUpdate sends TRANSIENT_FAILURE SubConn state updates
// for each SubConn managed by a pickfirst balancer. It verifies that the picker
// is updated with the expected frequency.
func (s) TestPickFirstLeaf_TFPickerUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cc := testutils.NewBalancerClientConn(t)
bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{})
defer bal.Close()
ccState := balancer.ClientConnState{
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: "1.1.1.1:1"}}},
{Addresses: []resolver.Address{{Addr: "2.2.2.2:2"}}},
},
},
}
if err := bal.UpdateClientConnState(ccState); err != nil {
t.Fatalf("UpdateClientConnState(%v) returned error: %v", ccState, err)
}
// PF should report TRANSIENT_FAILURE only once all the subconns have failed
// once.
tfErr := fmt.Errorf("test err: connection refused")
sc1 := <-cc.NewSubConnCh
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc1.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: tfErr})
if err := cc.WaitForPickerWithErr(ctx, balancer.ErrNoSubConnAvailable); err != nil {
t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", balancer.ErrNoSubConnAvailable, err)
}
sc2 := <-cc.NewSubConnCh
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: tfErr})
if err := cc.WaitForPickerWithErr(ctx, tfErr); err != nil {
t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", tfErr, err)
}
// Subsequent TRANSIENT_FAILURE picker updates should be reported only after
// each SubConn has seen one more TRANSIENT_FAILURE.
newTfErr := fmt.Errorf("test err: unreachable")
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: newTfErr})
select {
case <-time.After(defaultTestShortTimeout):
case p := <-cc.NewPickerCh:
sc, err := p.Pick(balancer.PickInfo{})
t.Fatalf("Unexpected picker update: %v, %v", sc, err)
}
sc2.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.TransientFailure, ConnectionError: newTfErr})
if err := cc.WaitForPickerWithErr(ctx, newTfErr); err != nil {
t.Fatalf("cc.WaitForPickerWithErr(%v) returned error: %v", newTfErr, err)
}
}
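// Editor's sketch: an illustrative model (not the pickfirstleaf implementation)
// of the picker-update cadence verified above: count TRANSIENT_FAILURE reports
// and refresh the picker once per full pass over the SubConns.
type tfCadenceSketch struct {
	numSubConns int
	numTF       int
}

// onTransientFailure records one failing SubConn report and returns true when
// a new TRANSIENT_FAILURE picker should be published, i.e. once every
// numSubConns failures.
func (c *tfCadenceSketch) onTransientFailure() bool {
	c.numTF = (c.numTF + 1) % c.numSubConns
	return c.numTF == 0
}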

View File

@ -1096,6 +1096,9 @@ func (s) TestUpdateStatePauses(t *testing.T) {
Init: func(bd *stub.BalancerData) {
bd.Data = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions)
},
Close: func(bd *stub.BalancerData) {
bd.Data.(balancer.Balancer).Close()
},
ParseConfig: func(sc json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
cfg := &childPolicyConfig{}
if err := json.Unmarshal(sc, cfg); err != nil {

View File

@ -1249,6 +1249,8 @@ func (ac *addrConn) resetTransportAndUnlock() {
ac.mu.Unlock()
if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil {
// TODO: #7534 - Move re-resolution requests into the pick_first LB policy
// to ensure one resolution request per pass instead of per subconn failure.
ac.cc.resolveNow(resolver.ResolveNowOptions{})
ac.mu.Lock()
if acCtx.Err() != nil {

View File

@ -37,6 +37,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
internalbackoff "google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/transport"
@ -418,17 +419,21 @@ func (s) TestWithTransportCredentialsTLS(t *testing.T) {
// When creating a transport configured with n addresses, only calculate the
// backoff once per "round" of attempts instead of once per address (n times
// per "round" of attempts).
func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) {
// per "round" of attempts) for old pickfirst and once per address for new pickfirst.
func (s) TestDial_BackoffCountPerRetryGroup(t *testing.T) {
var attempts uint32
wantBackoffs := uint32(1)
if envconfig.NewPickFirstEnabled {
wantBackoffs = 2
}
getMinConnectTimeout := func() time.Duration {
if atomic.AddUint32(&attempts, 1) == 1 {
if atomic.AddUint32(&attempts, 1) <= wantBackoffs {
// Once all addresses are exhausted, hang around and wait for the
// client.Close to happen rather than re-starting a new round of
// attempts.
return time.Hour
}
t.Error("only one attempt backoff calculation, but got more")
t.Errorf("only %d attempt backoff calculation, but got more", wantBackoffs)
return 0
}
@ -499,6 +504,10 @@ func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) {
t.Fatal("timed out waiting for test to finish")
case <-server2Done:
}
if got, want := atomic.LoadUint32(&attempts), wantBackoffs; got != want {
t.Errorf("attempts = %d, want %d", got, want)
}
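// Editor's note (hedged): two backoff computations are expected with the new
// pickfirst because it creates one SubConn per address (two addresses in this
// test), each computing its own backoff, whereas the old pickfirst used a
// single SubConn and one backoff calculation per round of attempts.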
}
func (s) TestDialContextCancel(t *testing.T) {
@ -1062,18 +1071,14 @@ func (s) TestUpdateAddresses_NoopIfCalledWithSameAddresses(t *testing.T) {
}
// Grab the addrConn and call tryUpdateAddrs.
var ac *addrConn
client.mu.Lock()
for clientAC := range client.conns {
ac = clientAC
break
// Call UpdateAddresses with the same list of addresses, it should be a noop
// (even when the SubConn is Connecting, and doesn't have a curAddr).
clientAC.acbw.UpdateAddresses(clientAC.addrs)
}
client.mu.Unlock()
// Call UpdateAddresses with the same list of addresses, it should be a noop
// (even when the SubConn is Connecting, and doesn't have a curAddr).
ac.acbw.UpdateAddresses(addrsList)
// We've called tryUpdateAddrs - now let's make server2 close the
// connection and check that it continues to server3.
close(closeServer2)

View File

@ -575,6 +575,7 @@ func (s) TestBalancerGracefulSwitch(t *testing.T) {
bg.UpdateClientConnState(testBalancerIDs[0], balancer.ClientConnState{ResolverState: resolver.State{Addresses: testBackendAddrs[0:2]}})
bg.Start()
defer bg.Close()
m1 := make(map[resolver.Address]balancer.SubConn)
scs := make(map[balancer.SubConn]bool)
@ -604,6 +605,9 @@ func (s) TestBalancerGracefulSwitch(t *testing.T) {
Init: func(bd *stub.BalancerData) {
bd.Data = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions)
},
Close: func(bd *stub.BalancerData) {
bd.Data.(balancer.Balancer).Close()
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
ccs.ResolverState.Addresses = ccs.ResolverState.Addresses[1:]
bal := bd.Data.(balancer.Balancer)

View File

@ -50,6 +50,11 @@ var (
// xDS fallback is turned on. If this is unset or is false, only the first
// xDS server in the list of server configs will be used.
XDSFallbackSupport = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FALLBACK", false)
// NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used
// instead of the existing pickfirst implementation. This can be enabled by
// setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"
// to "true".
NewPickFirstEnabled = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", false)
)
func boolFromEnv(envVar string, def bool) bool {

View File

@ -483,6 +483,9 @@ func (s) TestBalancerSwitch_Graceful(t *testing.T) {
pf := balancer.Get(pickfirst.Name)
bd.Data = pf.Build(bd.ClientConn, bd.BuildOptions)
},
Close: func(bd *stub.BalancerData) {
bd.Data.(balancer.Balancer).Close()
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
bal := bd.Data.(balancer.Balancer)
close(ccUpdateCh)

View File

@ -850,6 +850,9 @@ func (s) TestMetadataInPickResult(t *testing.T) {
cc := &testCCWrapper{ClientConn: bd.ClientConn}
bd.Data = balancer.Get(pickfirst.Name).Build(cc, bd.BuildOptions)
},
Close: func(bd *stub.BalancerData) {
bd.Data.(balancer.Balancer).Close()
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
bal := bd.Data.(balancer.Balancer)
return bal.UpdateClientConnState(ccs)

View File

@ -34,6 +34,7 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
@ -323,6 +324,13 @@ func (s) TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T)
client, err := grpc.Dial("whatever:///this-gets-overwritten",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateRecordingBalancerName)),
grpc.WithConnectParams(grpc.ConnectParams{
// Set a really long back-off delay to ensure the first subConn does
// not enter IDLE before the second subConn connects.
Backoff: backoff.Config{
BaseDelay: 1 * time.Hour,
},
}),
grpc.WithResolvers(rb))
if err != nil {
t.Fatal(err)
@ -334,6 +342,16 @@ func (s) TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T)
connectivity.Connecting,
connectivity.Ready,
}
if envconfig.NewPickFirstEnabled {
want = []connectivity.State{
// The first subconn fails.
connectivity.Connecting,
connectivity.TransientFailure,
// The second subconn connects.
connectivity.Connecting,
connectivity.Ready,
}
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < len(want); i++ {

View File

@ -162,6 +162,9 @@ func (s) TestResolverUpdate_InvalidServiceConfigAfterGoodUpdate(t *testing.T) {
pf := balancer.Get(pickfirst.Name)
bd.Data = pf.Build(bd.ClientConn, bd.BuildOptions)
},
Close: func(bd *stub.BalancerData) {
bd.Data.(balancer.Balancer).Close()
},
ParseConfig: func(lbCfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
cfg := &wrappingBalancerConfig{}
if err := json.Unmarshal(lbCfg, cfg); err != nil {

View File

@ -607,6 +607,7 @@ func TestClusterGracefulSwitch(t *testing.T) {
builder := balancer.Get(balancerName)
parser := builder.(balancer.ConfigParser)
bal := builder.Build(cc, balancer.BuildOptions{})
defer bal.Close()
configJSON1 := `{
"children": {
@ -644,6 +645,9 @@ func TestClusterGracefulSwitch(t *testing.T) {
Init: func(bd *stub.BalancerData) {
bd.Data = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions)
},
Close: func(bd *stub.BalancerData) {
bd.Data.(balancer.Balancer).Close()
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
bal := bd.Data.(balancer.Balancer)
return bal.UpdateClientConnState(ccs)
@ -730,6 +734,7 @@ func (s) TestUpdateStatePauses(t *testing.T) {
builder := balancer.Get(balancerName)
parser := builder.(balancer.ConfigParser)
bal := builder.Build(cc, balancer.BuildOptions{})
defer bal.Close()
configJSON1 := `{
"children": {