mirror of https://github.com/grpc/grpc-go.git
examples: Add custom load balancer example (#6691)
This commit is contained in:
parent fc8da03081
commit 431436d66b
@@ -0,0 +1,293 @@
/*
 *
 * Copyright 2024 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package endpointsharding implements a load balancing policy that manages
// homogeneous child policies, each owning a single endpoint.
//
// # Experimental
//
// Notice: This package is EXPERIMENTAL and may be changed or removed in a
// later release.
package endpointsharding

import (
    "encoding/json"
    "errors"
    "fmt"
    "sync"
    "sync/atomic"

    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/balancer/base"
    "google.golang.org/grpc/connectivity"
    "google.golang.org/grpc/internal/balancer/gracefulswitch"
    "google.golang.org/grpc/internal/grpcrand"
    "google.golang.org/grpc/resolver"
    "google.golang.org/grpc/serviceconfig"
)

// ChildState is the balancer state of a child, along with the endpoint which
// identifies the child balancer.
type ChildState struct {
    Endpoint resolver.Endpoint
    State    balancer.State
}

// NewBalancer returns a load balancing policy that manages homogeneous child
// policies, each owning a single endpoint.
func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
    es := &endpointSharding{
        cc:    cc,
        bOpts: opts,
    }
    es.children.Store(resolver.NewEndpointMap())
    return es
}

// endpointSharding is a balancer that wraps child balancers. It creates a child
// balancer with child config for every unique Endpoint received. It updates the
// child states on any update from parent or child.
type endpointSharding struct {
    cc    balancer.ClientConn
    bOpts balancer.BuildOptions

    children atomic.Pointer[resolver.EndpointMap]

    // inhibitChildUpdates is set during UpdateClientConnState/ResolverError
    // calls (calls to children will each produce an update; we only want one
    // update).
    inhibitChildUpdates atomic.Bool

    mu sync.Mutex // Sync updateState callouts and childState recent state updates
}

// UpdateClientConnState creates a child for new endpoints and deletes children
// for endpoints that are no longer present. It also updates all the children,
// and sends a single synchronous update of the children's aggregated state at
// the end of the UpdateClientConnState operation. If any endpoint has no
// addresses, it returns an error without forwarding any updates. Otherwise it
// returns the first error found from a child, but fully processes the new
// update.
func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState) error {
    if len(state.ResolverState.Endpoints) == 0 {
        return errors.New("endpoints list is empty")
    }
    // Check/return early if any endpoints have no addresses.
    // TODO: make this configurable if needed.
    for i, endpoint := range state.ResolverState.Endpoints {
        if len(endpoint.Addresses) == 0 {
            return fmt.Errorf("endpoint %d has empty addresses", i)
        }
    }

    es.inhibitChildUpdates.Store(true)
    defer func() {
        es.inhibitChildUpdates.Store(false)
        es.updateState()
    }()
    var ret error

    children := es.children.Load()
    newChildren := resolver.NewEndpointMap()

    // Update/Create new children.
    for _, endpoint := range state.ResolverState.Endpoints {
        if _, ok := newChildren.Get(endpoint); ok {
            // Endpoint child was already created, continue to avoid duplicate
            // update.
            continue
        }
        var bal *balancerWrapper
        if child, ok := children.Get(endpoint); ok {
            bal = child.(*balancerWrapper)
        } else {
            bal = &balancerWrapper{
                childState: ChildState{Endpoint: endpoint},
                ClientConn: es.cc,
                es:         es,
            }
            bal.Balancer = gracefulswitch.NewBalancer(bal, es.bOpts)
        }
        newChildren.Set(endpoint, bal)
        if err := bal.UpdateClientConnState(balancer.ClientConnState{
            BalancerConfig: state.BalancerConfig,
            ResolverState: resolver.State{
                Endpoints:  []resolver.Endpoint{endpoint},
                Attributes: state.ResolverState.Attributes,
            },
        }); err != nil && ret == nil {
            // Return the first error found, but always commit full processing
            // of updating children. If more specific errors across all
            // endpoints are desired, the caller should make those specific
            // validations; this is a current limitation for simplicity's sake.
            ret = err
        }
    }
    // Delete old children that are no longer present.
    for _, e := range children.Keys() {
        child, _ := children.Get(e)
        bal := child.(balancer.Balancer)
        if _, ok := newChildren.Get(e); !ok {
            bal.Close()
        }
    }
    es.children.Store(newChildren)
    return ret
}

// ResolverError forwards the resolver error to all of the endpointSharding's
// children and sends a single synchronous update of the childStates at the end
// of the ResolverError operation.
func (es *endpointSharding) ResolverError(err error) {
    es.inhibitChildUpdates.Store(true)
    defer func() {
        es.inhibitChildUpdates.Store(false)
        es.updateState()
    }()
    children := es.children.Load()
    for _, child := range children.Values() {
        bal := child.(balancer.Balancer)
        bal.ResolverError(err)
    }
}

func (es *endpointSharding) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
    // UpdateSubConnState is deprecated.
}

func (es *endpointSharding) Close() {
    children := es.children.Load()
    for _, child := range children.Values() {
        bal := child.(balancer.Balancer)
        bal.Close()
    }
}

// updateState updates this component's state. It sends the aggregated state,
// and, if needed, a picker with round robin behavior over all the child states
// present.
func (es *endpointSharding) updateState() {
    if es.inhibitChildUpdates.Load() {
        return
    }
    var readyPickers, connectingPickers, idlePickers, transientFailurePickers []balancer.Picker

    es.mu.Lock()
    defer es.mu.Unlock()

    children := es.children.Load()
    childStates := make([]ChildState, 0, children.Len())

    for _, child := range children.Values() {
        bw := child.(*balancerWrapper)
        childState := bw.childState
        childStates = append(childStates, childState)
        childPicker := childState.State.Picker
        switch childState.State.ConnectivityState {
        case connectivity.Ready:
            readyPickers = append(readyPickers, childPicker)
        case connectivity.Connecting:
            connectingPickers = append(connectingPickers, childPicker)
        case connectivity.Idle:
            idlePickers = append(idlePickers, childPicker)
        case connectivity.TransientFailure:
            transientFailurePickers = append(transientFailurePickers, childPicker)
            // connectivity.Shutdown shouldn't appear.
        }
    }

    // Construct the round robin picker based off the aggregated state.
    // Whatever the aggregated state, use only the pickers that are currently
    // in that state.
    var aggState connectivity.State
    var pickers []balancer.Picker
    if len(readyPickers) >= 1 {
        aggState = connectivity.Ready
        pickers = readyPickers
    } else if len(connectingPickers) >= 1 {
        aggState = connectivity.Connecting
        pickers = connectingPickers
    } else if len(idlePickers) >= 1 {
        aggState = connectivity.Idle
        pickers = idlePickers
    } else if len(transientFailurePickers) >= 1 {
        aggState = connectivity.TransientFailure
        pickers = transientFailurePickers
    } else {
        aggState = connectivity.TransientFailure
        pickers = []balancer.Picker{base.NewErrPicker(errors.New("no children to pick from"))}
    } // No children (resolver error before valid update).
    p := &pickerWithChildStates{
        pickers:     pickers,
        childStates: childStates,
        next:        uint32(grpcrand.Intn(len(pickers))),
    }
    es.cc.UpdateState(balancer.State{
        ConnectivityState: aggState,
        Picker:            p,
    })
}

// pickerWithChildStates delegates to the pickers it holds in a round robin
// fashion. It also contains the childStates of all the endpointSharding's
// children.
type pickerWithChildStates struct {
    pickers     []balancer.Picker
    childStates []ChildState
    next        uint32
}

func (p *pickerWithChildStates) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
    nextIndex := atomic.AddUint32(&p.next, 1)
    picker := p.pickers[nextIndex%uint32(len(p.pickers))]
    return picker.Pick(info)
}

// ChildStatesFromPicker returns the state of all the children managed by the
// endpoint sharding balancer that created this picker.
func ChildStatesFromPicker(picker balancer.Picker) []ChildState {
    p, ok := picker.(*pickerWithChildStates)
    if !ok {
        return nil
    }
    return p.childStates
}

// balancerWrapper is a wrapper of a balancer. It identifies a child balancer
// by endpoint, and persists the most recent child balancer state.
type balancerWrapper struct {
    balancer.Balancer   // Simply forward balancer.Balancer operations.
    balancer.ClientConn // Embed to intercept UpdateState; doesn't deal with SubConns.

    es *endpointSharding

    childState ChildState
}

func (bw *balancerWrapper) UpdateState(state balancer.State) {
    bw.es.mu.Lock()
    bw.childState.State = state
    bw.es.mu.Unlock()
    bw.es.updateState()
}

// ParseConfig parses a child config list and returns a load balancing config
// to use with the endpointsharding balancer.
func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
    return gracefulswitch.ParseConfig(cfg)
}

// PickFirstConfig is a pick first config without shuffling enabled.
const PickFirstConfig = "[{\"pick_first\": {}}]"
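To illustrate how a wrapping policy is expected to consume these child states, here is a minimal sketch (not part of this commit) of an UpdateState override built on ChildStatesFromPicker. The myWrapper type and its fields are hypothetical names; the import path is the one used by the example later in this commit.

```
package example

import (
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/balancer/endpointsharding"
    "google.golang.org/grpc/connectivity"
)

// myWrapper is a hypothetical policy that embeds the balancer returned by
// endpointsharding.NewBalancer and the parent ClientConn, in the same style
// as the fakePetiole in the test file below.
type myWrapper struct {
    balancer.Balancer   // The endpointsharding child.
    balancer.ClientConn // The parent ClientConn.
}

// UpdateState intercepts the aggregated state and inspects each endpoint's
// most recent child state before forwarding upward.
func (w *myWrapper) UpdateState(state balancer.State) {
    var ready int
    for _, cs := range endpointsharding.ChildStatesFromPicker(state.Picker) {
        if cs.State.ConnectivityState == connectivity.Ready {
            ready++ // A custom policy could collect cs.State.Picker here.
        }
    }
    // This sketch forwards the default round robin state unchanged; a custom
    // policy could instead build its own picker from the READY children, as
    // custom_round_robin does later in this commit.
    w.ClientConn.UpdateState(state)
}
```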
@@ -0,0 +1,159 @@
/*
 *
 * Copyright 2024 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package endpointsharding

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "testing"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/credentials/insecure"
    "google.golang.org/grpc/grpclog"
    "google.golang.org/grpc/internal"
    "google.golang.org/grpc/internal/grpctest"
    "google.golang.org/grpc/internal/stubserver"
    "google.golang.org/grpc/internal/testutils/roundrobin"
    "google.golang.org/grpc/resolver"
    "google.golang.org/grpc/resolver/manual"
    "google.golang.org/grpc/serviceconfig"

    testgrpc "google.golang.org/grpc/interop/grpc_testing"
)

type s struct {
    grpctest.Tester
}

func Test(t *testing.T) {
    grpctest.RunSubTests(t, s{})
}

var gracefulSwitchPickFirst serviceconfig.LoadBalancingConfig

var logger = grpclog.Component("endpoint-sharding-test")

func init() {
    var err error
    gracefulSwitchPickFirst, err = ParseConfig(json.RawMessage(PickFirstConfig))
    if err != nil {
        logger.Fatal(err)
    }
    balancer.Register(fakePetioleBuilder{})
}

const fakePetioleName = "fake_petiole"

type fakePetioleBuilder struct{}

func (fakePetioleBuilder) Name() string {
    return fakePetioleName
}

func (fakePetioleBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
    fp := &fakePetiole{
        ClientConn: cc,
        bOpts:      opts,
    }
    fp.Balancer = NewBalancer(fp, opts)
    return fp
}

func (fakePetioleBuilder) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
    return nil, nil
}

// fakePetiole is a load balancer that wraps the endpointsharding balancer, and
// forwards ClientConn updates with a child config of graceful switch that
// wraps pick first. It also intercepts UpdateState to make sure it can access
// the child state maintained by endpointsharding.
type fakePetiole struct {
    balancer.Balancer
    balancer.ClientConn
    bOpts balancer.BuildOptions
}

func (fp *fakePetiole) UpdateClientConnState(state balancer.ClientConnState) error {
    if el := state.ResolverState.Endpoints; len(el) != 2 {
        return fmt.Errorf("UpdateClientConnState wants two endpoints, got: %v", el)
    }

    return fp.Balancer.UpdateClientConnState(balancer.ClientConnState{
        BalancerConfig: gracefulSwitchPickFirst,
        ResolverState:  state.ResolverState,
    })
}

func (fp *fakePetiole) UpdateState(state balancer.State) {
    childStates := ChildStatesFromPicker(state.Picker)
    // Both child states should be present in the child picker. The states and
    // picker change over the lifecycle of the test, but there should always
    // be two children.
    if len(childStates) != 2 {
        logger.Fatal(fmt.Errorf("length of child states received: %v, want 2", len(childStates)))
    }

    fp.ClientConn.UpdateState(state)
}

// TestEndpointShardingBasic tests the basic functionality of the endpoint
// sharding balancer. It specifies a petiole policy that is essentially a
// wrapper around the endpoint sharder. Two backends are started, with each
// backend's address specified in an endpoint. The petiole does not have a
// special picker, so it should fall back to the default behavior, which is to
// round_robin amongst the endpoint children that are in the aggregated state.
// It also verifies the petiole has access to the raw child state in case it
// wants to implement a custom picker.
func (s) TestEndpointShardingBasic(t *testing.T) {
    backend1 := stubserver.StartTestService(t, nil)
    defer backend1.Stop()
    backend2 := stubserver.StartTestService(t, nil)
    defer backend2.Stop()

    mr := manual.NewBuilderWithScheme("e2e-test")
    defer mr.Close()

    json := `{"loadBalancingConfig": [{"fake_petiole":{}}]}`
    sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
    mr.InitialState(resolver.State{
        Endpoints: []resolver.Endpoint{
            {Addresses: []resolver.Address{{Addr: backend1.Address}}},
            {Addresses: []resolver.Address{{Addr: backend2.Address}}},
        },
        ServiceConfig: sc,
    })

    cc, err := grpc.Dial(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        log.Fatalf("Failed to dial: %v", err)
    }
    defer cc.Close()
    ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
    defer cancel()
    client := testgrpc.NewTestServiceClient(cc)
    // Assert a round robin distribution between the two spun-up backends. This
    // requires a poll and eventual consistency, as both endpoint children do
    // not start in state READY.
    if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend1.Address}, {Addr: backend2.Address}}); err != nil {
        t.Fatalf("error in expected round robin: %v", err)
    }
}
@@ -55,6 +55,7 @@ EXAMPLES=(
     "features/authz"
     "features/cancellation"
     "features/compression"
+    "features/customloadbalancer"
     "features/deadline"
     "features/encryption/TLS"
     "features/error_details"

@@ -109,6 +110,7 @@ declare -A EXPECTED_SERVER_OUTPUT=(
     ["features/authz"]="unary echoing message \"hello world\""
     ["features/cancellation"]="server: error receiving from stream: rpc error: code = Canceled desc = context canceled"
     ["features/compression"]="UnaryEcho called with message \"compress\""
+    ["features/customloadbalancer"]="serving on localhost:50051"
     ["features/deadline"]=""
     ["features/encryption/TLS"]=""
     ["features/error_details"]=""

@@ -132,6 +134,7 @@ declare -A EXPECTED_CLIENT_OUTPUT=(
     ["features/authz"]="UnaryEcho: hello world"
     ["features/cancellation"]="cancelling context"
     ["features/compression"]="UnaryEcho call returned \"compress\", <nil>"
+    ["features/customloadbalancer"]="Successful multiple iterations of 1:2 ratio"
     ["features/deadline"]="wanted = DeadlineExceeded, got = DeadlineExceeded"
     ["features/encryption/TLS"]="UnaryEcho: hello world"
     ["features/error_details"]="Greeting: Hello world"
@@ -0,0 +1,52 @@
# Custom Load Balancer

This example shows how to deploy a custom load balancer in a `ClientConn`.

## Try it

```
go run server/main.go
```

```
go run client/main.go
```

## Explanation

Two echo servers are serving on "localhost:50050" and "localhost:50051". They
include their serving address in the response, so the server on
"localhost:50051" replies to the RPC with `this is
examples/customloadbalancing (from localhost:50051)`.

A client is created that connects to both of these servers (it receives both
server addresses from the name resolver as two separate endpoints). The client
is configured with the load balancer specified in the service config, which in
this case is `custom_round_robin`.

### custom_round_robin

The client is configured to use `custom_round_robin`. `custom_round_robin`
creates a pick first child for every endpoint it receives. It waits until both
pick first children become ready, then defers to the first pick first child's
picker, choosing the connection to localhost:50050, except every `chooseSecond`
picks, where it defers to the second pick first child's picker, choosing the
connection to localhost:50051 (or vice versa).

`custom_round_robin` is written as a delegating policy wrapping `pick_first`
load balancers, one for every endpoint received. This is the intended way a
user-written custom LB should be structured, as pick first contains a lot of
useful functionality, such as sticky transient failure, Happy Eyeballs, and
health checking.

```
this is examples/customloadbalancing (from localhost:50050)
this is examples/customloadbalancing (from localhost:50050)
this is examples/customloadbalancing (from localhost:50051)
this is examples/customloadbalancing (from localhost:50050)
this is examples/customloadbalancing (from localhost:50050)
this is examples/customloadbalancing (from localhost:50051)
this is examples/customloadbalancing (from localhost:50050)
this is examples/customloadbalancing (from localhost:50050)
this is examples/customloadbalancing (from localhost:50051)
```
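To make the `chooseSecond` behavior concrete, here is a minimal, runnable sketch (not part of this commit) of the picker's index arithmetic with the default `chooseSecond` of 3, mirroring `customRoundRobinPicker.Pick` in the diff below:

```
package main

import "fmt"

// Simulate the custom_round_robin pick sequence for chooseSecond = 3: every
// third pick goes to the second child, all others go to the first.
func main() {
    const chooseSecond = 3
    var next uint32
    for pick := 1; pick <= 9; pick++ {
        next++ // Stands in for atomic.AddUint32(&crrp.next, 1).
        index := 0
        if next != 0 && next%chooseSecond == 0 {
            index = 1
        }
        fmt.Printf("pick %d -> child %d\n", pick, index)
    }
    // Prints child 0, child 0, child 1, repeating: the 1:2 ratio that the
    // client's waitForDistribution polls for.
}
```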
@@ -0,0 +1,157 @@
/*
 *
 * Copyright 2023 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package customroundrobin

import (
    "encoding/json"
    "fmt"
    "sync/atomic"

    _ "google.golang.org/grpc" // to register pick_first
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/balancer/endpointsharding"
    "google.golang.org/grpc/connectivity"
    "google.golang.org/grpc/grpclog"
    "google.golang.org/grpc/serviceconfig"
)

var gracefulSwitchPickFirst serviceconfig.LoadBalancingConfig

func init() {
    balancer.Register(customRoundRobinBuilder{})
    var err error
    gracefulSwitchPickFirst, err = endpointsharding.ParseConfig(json.RawMessage(endpointsharding.PickFirstConfig))
    if err != nil {
        logger.Fatal(err)
    }
}

const customRRName = "custom_round_robin"

type customRRConfig struct {
    serviceconfig.LoadBalancingConfig `json:"-"`

    // ChooseSecond represents how often pick iterations choose the second
    // SubConn in the list. Defaults to 3. If 0, the second SubConn is never
    // chosen.
    ChooseSecond uint32 `json:"chooseSecond,omitempty"`
}

type customRoundRobinBuilder struct{}

func (customRoundRobinBuilder) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
    lbConfig := &customRRConfig{
        ChooseSecond: 3,
    }

    if err := json.Unmarshal(s, lbConfig); err != nil {
        return nil, fmt.Errorf("custom-round-robin: unable to unmarshal customRRConfig: %v", err)
    }
    return lbConfig, nil
}

func (customRoundRobinBuilder) Name() string {
    return customRRName
}

func (customRoundRobinBuilder) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
    crr := &customRoundRobin{
        ClientConn: cc,
        bOpts:      bOpts,
    }
    crr.Balancer = endpointsharding.NewBalancer(crr, bOpts)
    return crr
}

var logger = grpclog.Component("example")

type customRoundRobin struct {
    // All state and operations on this balancer are either initialized at
    // build time and read-only after, or are only accessed as part of its
    // balancer.Balancer API (UpdateState from children only comes in from
    // balancer.Balancer calls as well, and children are called one at a time),
    // in which calls are guaranteed to come synchronously. Thus, no extra
    // synchronization is required in this balancer.
    balancer.Balancer
    balancer.ClientConn
    bOpts balancer.BuildOptions

    cfg atomic.Pointer[customRRConfig]
}

func (crr *customRoundRobin) UpdateClientConnState(state balancer.ClientConnState) error {
    crrCfg, ok := state.BalancerConfig.(*customRRConfig)
    if !ok {
        return balancer.ErrBadResolverState
    }
    if el := state.ResolverState.Endpoints; len(el) != 2 {
        return fmt.Errorf("UpdateClientConnState wants two endpoints, got: %v", el)
    }
    crr.cfg.Store(crrCfg)
    // A call to UpdateClientConnState should always produce a new Picker. That
    // is guaranteed to happen since the aggregator will always call
    // UpdateChildState in its UpdateClientConnState.
    return crr.Balancer.UpdateClientConnState(balancer.ClientConnState{
        BalancerConfig: gracefulSwitchPickFirst,
        ResolverState:  state.ResolverState,
    })
}

func (crr *customRoundRobin) UpdateState(state balancer.State) {
    if state.ConnectivityState == connectivity.Ready {
        childStates := endpointsharding.ChildStatesFromPicker(state.Picker)
        var readyPickers []balancer.Picker
        for _, childState := range childStates {
            if childState.State.ConnectivityState == connectivity.Ready {
                readyPickers = append(readyPickers, childState.State.Picker)
            }
        }
        // If both children are ready, pick using the custom round robin
        // algorithm.
        if len(readyPickers) == 2 {
            picker := &customRoundRobinPicker{
                pickers:      readyPickers,
                chooseSecond: crr.cfg.Load().ChooseSecond,
                next:         0,
            }
            crr.ClientConn.UpdateState(balancer.State{
                ConnectivityState: connectivity.Ready,
                Picker:            picker,
            })
            return
        }
    }
    // Delegate to the default behavior/picker from below.
    crr.ClientConn.UpdateState(state)
}

type customRoundRobinPicker struct {
    pickers      []balancer.Picker
    chooseSecond uint32
    next         uint32
}

func (crrp *customRoundRobinPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
    next := atomic.AddUint32(&crrp.next, 1)
    index := 0
    if next != 0 && next%crrp.chooseSecond == 0 {
        index = 1
    }
    childPicker := crrp.pickers[index%len(crrp.pickers)]
    return childPicker.Pick(info)
}
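As a quick illustration of the defaulting behavior in ParseConfig above, a hypothetical snippet assumed to live in the same customroundrobin package (it is not part of this commit):

```
// exampleParseConfig is a hypothetical helper showing how ParseConfig applies
// the chooseSecond default; it assumes it lives alongside the builder above.
func exampleParseConfig() {
    // An explicit value overrides the default of 3.
    cfg, err := customRoundRobinBuilder{}.ParseConfig(json.RawMessage(`{"chooseSecond": 5}`))
    if err != nil {
        logger.Fatal(err)
    }
    fmt.Println(cfg.(*customRRConfig).ChooseSecond) // 5

    // An empty config keeps the default.
    cfg, _ = customRoundRobinBuilder{}.ParseConfig(json.RawMessage(`{}`))
    fmt.Println(cfg.(*customRRConfig).ChooseSecond) // 3
}
```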
@@ -0,0 +1,154 @@
/*
 *
 * Copyright 2023 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package main

import (
    "context"
    "fmt"
    "log"
    "strings"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
    _ "google.golang.org/grpc/examples/features/customloadbalancer/client/customroundrobin" // To register custom_round_robin.
    pb "google.golang.org/grpc/examples/features/proto/echo"
    "google.golang.org/grpc/internal"
    "google.golang.org/grpc/peer"
    "google.golang.org/grpc/resolver"
    "google.golang.org/grpc/resolver/manual"
    "google.golang.org/grpc/serviceconfig"
)

var (
    addr1 = "localhost:50050"
    addr2 = "localhost:50051"
)

func main() {
    mr := manual.NewBuilderWithScheme("example")
    defer mr.Close()

    // You can also plug in your own custom lb policy, which needs to be
    // configurable. Here, the chooseSecond field is the configurable knob.
    // Try changing it and see how the behavior changes.
    json := `{"loadBalancingConfig": [{"custom_round_robin":{"chooseSecond": 3}}]}`
    sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
    mr.InitialState(resolver.State{
        Endpoints: []resolver.Endpoint{
            {Addresses: []resolver.Address{{Addr: addr1}}},
            {Addresses: []resolver.Address{{Addr: addr2}}},
        },
        ServiceConfig: sc,
    })

    cc, err := grpc.Dial(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        log.Fatalf("Failed to dial: %v", err)
    }
    defer cc.Close()
    ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
    defer cancel()
    ec := pb.NewEchoClient(cc)
    if err := waitForDistribution(ctx, ec); err != nil {
        log.Fatalf(err.Error())
    }
    fmt.Println("Successful multiple iterations of 1:2 ratio")
}

// waitForDistribution makes RPCs on the echo client until 3 iterations of RPCs
// follow the same 1:2 address ratio for the peer. Returns an error if it fails
// to do so before the context times out.
func waitForDistribution(ctx context.Context, ec pb.EchoClient) error {
    for {
        results := make(map[string]uint32)
    InnerLoop:
        for {
            if ctx.Err() != nil {
                return fmt.Errorf("timeout waiting for 1:2 distribution between addresses %v and %v", addr1, addr2)
            }

            for i := 0; i < 3; i++ {
                res := make(map[string]uint32)
                for j := 0; j < 3; j++ {
                    var peer peer.Peer
                    r, err := ec.UnaryEcho(ctx, &pb.EchoRequest{Message: "this is examples/customloadbalancing"}, grpc.Peer(&peer))
                    if err != nil {
                        return fmt.Errorf("UnaryEcho failed: %v", err)
                    }
                    fmt.Println(r)
                    peerAddr := peer.Addr.String()
                    if !strings.HasSuffix(peerAddr, "50050") && !strings.HasSuffix(peerAddr, "50051") {
                        return fmt.Errorf("peer address was not one of %v or %v, got: %v", addr1, addr2, peerAddr)
                    }
                    res[peerAddr]++
                    time.Sleep(time.Millisecond)
                }
                // Make sure the addresses come in a 1:2 ratio for this
                // iteration.
                var seen1, seen2 bool
                for addr, count := range res {
                    if count != 1 && count != 2 {
                        break InnerLoop
                    }
                    if count == 1 {
                        if seen1 {
                            break InnerLoop
                        }
                        seen1 = true
                    }
                    if count == 2 {
                        if seen2 {
                            break InnerLoop
                        }
                        seen2 = true
                    }
                    results[addr] = results[addr] + count
                }
                if !seen1 || !seen2 {
                    break InnerLoop
                }
            }
            // Make sure the totals are 3 and 6 for the addresses seen. This
            // makes sure the distribution is the same 1:2 ratio for each
            // iteration.
            var seen3, seen6 bool
            for _, count := range results {
                if count != 3 && count != 6 {
                    break InnerLoop
                }
                if count == 3 {
                    if seen3 {
                        break InnerLoop
                    }
                    seen3 = true
                }
                if count == 6 {
                    if seen6 {
                        break InnerLoop
                    }
                    seen6 = true
                }
            }
            if !seen3 || !seen6 {
                break InnerLoop
            }
            return nil
        }
    }
}
@@ -0,0 +1,66 @@
/*
 *
 * Copyright 2023 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package main

import (
    "context"
    "fmt"
    "log"
    "net"
    "sync"

    "google.golang.org/grpc"
    pb "google.golang.org/grpc/examples/features/proto/echo"
)

var (
    addrs = []string{"localhost:50050", "localhost:50051"}
)

type echoServer struct {
    pb.UnimplementedEchoServer
    addr string
}

func (s *echoServer) UnaryEcho(ctx context.Context, req *pb.EchoRequest) (*pb.EchoResponse, error) {
    return &pb.EchoResponse{Message: fmt.Sprintf("%s (from %s)", req.Message, s.addr)}, nil
}

func main() {
    var wg sync.WaitGroup
    for _, addr := range addrs {
        lis, err := net.Listen("tcp", addr)
        if err != nil {
            log.Fatalf("failed to listen: %v", err)
        }
        s := grpc.NewServer()
        pb.RegisterEchoServer(s, &echoServer{
            addr: addr,
        })
        log.Printf("serving on %s\n", addr)
        wg.Add(1)
        go func() {
            defer wg.Done()
            if err := s.Serve(lis); err != nil {
                log.Fatalf("failed to serve: %v", err)
            }
        }()
    }
    wg.Wait()
}
@@ -75,7 +75,6 @@ func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error)
    if err != nil {
        return nil, fmt.Errorf("error parsing config for policy %q: %v", name, err)
    }

    return &lbConfig{childBuilder: builder, childConfig: cfg}, nil
}

@@ -169,7 +169,6 @@ func (gsb *Balancer) latestBalancer() *balancerWrapper {
func (gsb *Balancer) UpdateClientConnState(state balancer.ClientConnState) error {
    // The resolver data is only relevant to the most recent LB Policy.
    balToUpdate := gsb.latestBalancer()

    gsbCfg, ok := state.BalancerConfig.(*lbConfig)
    if ok {
        // Switch to the child in the config unless it is already active.
pickfirst.go
@@ -54,7 +54,7 @@ type pfConfig struct {
 	serviceconfig.LoadBalancingConfig `json:"-"`

 	// If set to true, instructs the LB policy to shuffle the order of the list
-	// of addresses received from the name resolver before attempting to
+	// of endpoints received from the name resolver before attempting to
 	// connect to them.
 	ShuffleAddressList bool `json:"shuffleAddressList"`
 }

@@ -94,8 +94,7 @@ func (b *pickfirstBalancer) ResolverError(err error) {
 }

 func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
-	addrs := state.ResolverState.Addresses
-	if len(addrs) == 0 {
+	if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 {
 		// The resolver reported an empty address list. Treat it like an error by
 		// calling b.ResolverError.
 		if b.subConn != nil {

@@ -107,22 +106,49 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState
 		b.ResolverError(errors.New("produced zero addresses"))
 		return balancer.ErrBadResolverState
 	}

 	// We don't have to guard this block with the env var because ParseConfig
 	// already does so.
 	cfg, ok := state.BalancerConfig.(pfConfig)
 	if state.BalancerConfig != nil && !ok {
 		return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v", state.BalancerConfig, state.BalancerConfig)
 	}
-	if cfg.ShuffleAddressList {
-		addrs = append([]resolver.Address{}, addrs...)
-		grpcrand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
-	}

 	if b.logger.V(2) {
 		b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState))
 	}

+	var addrs []resolver.Address
+	if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 {
+		// Perform the optional shuffling described in gRFC A62. The shuffling
+		// will change the order of endpoints but not touch the order of the
+		// addresses within each endpoint. - A61
+		if cfg.ShuffleAddressList {
+			endpoints = append([]resolver.Endpoint{}, endpoints...)
+			grpcrand.Shuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
+		}
+
+		// "Flatten the list by concatenating the ordered list of addresses for
+		// each of the endpoints, in order." - A61
+		for _, endpoint := range endpoints {
+			// "In the flattened list, interleave addresses from the two address
+			// families, as per RFC-8305 section 4." - A61
+			// TODO: support the above language.
+			addrs = append(addrs, endpoint.Addresses...)
+		}
+	} else {
+		// Endpoints are not set; process addresses until we migrate resolver
+		// emissions fully to Endpoints. The top channel does wrap emitted
+		// addresses with endpoints, however some balancers such as weighted
+		// target do not forward the corresponding correct endpoints down/split
+		// endpoints properly. Once all balancers correctly forward endpoints
+		// down, we can delete this else conditional.
+		addrs = state.ResolverState.Addresses
+		if cfg.ShuffleAddressList {
+			addrs = append([]resolver.Address{}, addrs...)
+			grpcrand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] })
+		}
+	}

 	if b.subConn != nil {
 		b.cc.UpdateAddresses(b.subConn, addrs)
 		return nil
@@ -397,7 +397,10 @@ func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {

 	// Push an update with both addresses and shuffling disabled. We should
 	// connect to backend 0.
-	r.UpdateState(resolver.State{Addresses: []resolver.Address{addrs[0], addrs[1]}})
+	r.UpdateState(resolver.State{Endpoints: []resolver.Endpoint{
+		{Addresses: []resolver.Address{addrs[0]}},
+		{Addresses: []resolver.Address{addrs[1]}},
+	}})
 	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
 		t.Fatal(err)
 	}

@@ -406,7 +409,10 @@ func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {
 	// but the channel should still be connected to backend 0.
 	shufState := resolver.State{
 		ServiceConfig: parseServiceConfig(t, r, serviceConfig),
-		Addresses:     []resolver.Address{addrs[0], addrs[1]},
+		Endpoints: []resolver.Endpoint{
+			{Addresses: []resolver.Address{addrs[0]}},
+			{Addresses: []resolver.Address{addrs[1]}},
+		},
 	}
 	r.UpdateState(shufState)
 	if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {