mirror of https://github.com/grpc/grpc-go.git
244 lines
7.4 KiB
Go
244 lines
7.4 KiB
Go
/*
|
|
*
|
|
* Copyright 2023 gRPC authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
// Package leastrequest implements a least request load balancer.
|
|
package leastrequest
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
rand "math/rand/v2"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"google.golang.org/grpc/balancer"
|
|
"google.golang.org/grpc/balancer/endpointsharding"
|
|
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
|
|
"google.golang.org/grpc/connectivity"
|
|
"google.golang.org/grpc/grpclog"
|
|
internalgrpclog "google.golang.org/grpc/internal/grpclog"
|
|
"google.golang.org/grpc/resolver"
|
|
"google.golang.org/grpc/serviceconfig"
|
|
)
|
|
|
|
// Name is the name of the least request balancer.
|
|
const Name = "least_request_experimental"
|
|
|
|
var (
|
|
// randuint32 is a global to stub out in tests.
|
|
randuint32 = rand.Uint32
|
|
endpointShardingLBConfig = endpointsharding.PickFirstConfig
|
|
logger = grpclog.Component("least-request")
|
|
)
|
|
|
|
func init() {
|
|
balancer.Register(bb{})
|
|
}
|
|
|
|
// LBConfig is the balancer config for least_request_experimental balancer.
|
|
type LBConfig struct {
|
|
serviceconfig.LoadBalancingConfig `json:"-"`
|
|
|
|
// ChoiceCount is the number of random SubConns to sample to find the one
|
|
// with the fewest outstanding requests. If unset, defaults to 2. If set to
|
|
// < 2, the config will be rejected, and if set to > 10, will become 10.
|
|
ChoiceCount uint32 `json:"choiceCount,omitempty"`
|
|
}
|
|
|
|
type bb struct{}
|
|
|
|
func (bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
|
|
lbConfig := &LBConfig{
|
|
ChoiceCount: 2,
|
|
}
|
|
if err := json.Unmarshal(s, lbConfig); err != nil {
|
|
return nil, fmt.Errorf("least-request: unable to unmarshal LBConfig: %v", err)
|
|
}
|
|
// "If `choice_count < 2`, the config will be rejected." - A48
|
|
if lbConfig.ChoiceCount < 2 { // sweet
|
|
return nil, fmt.Errorf("least-request: lbConfig.choiceCount: %v, must be >= 2", lbConfig.ChoiceCount)
|
|
}
|
|
// "If a LeastRequestLoadBalancingConfig with a choice_count > 10 is
|
|
// received, the least_request_experimental policy will set choice_count =
|
|
// 10." - A48
|
|
if lbConfig.ChoiceCount > 10 {
|
|
lbConfig.ChoiceCount = 10
|
|
}
|
|
return lbConfig, nil
|
|
}
|
|
|
|
func (bb) Name() string {
|
|
return Name
|
|
}
|
|
|
|
func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
|
|
b := &leastRequestBalancer{
|
|
ClientConn: cc,
|
|
endpointRPCCounts: resolver.NewEndpointMap(),
|
|
}
|
|
b.child = endpointsharding.NewBalancer(b, bOpts)
|
|
b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[%p] ", b))
|
|
b.logger.Infof("Created")
|
|
return b
|
|
}
|
|
|
|
type leastRequestBalancer struct {
|
|
// Embeds balancer.Balancer because needs to intercept UpdateClientConnState
|
|
// to learn about choiceCount.
|
|
balancer.Balancer
|
|
// Embeds balancer.ClientConn because needs to intercept UpdateState calls
|
|
// from the child balancer.
|
|
balancer.ClientConn
|
|
child balancer.Balancer
|
|
logger *internalgrpclog.PrefixLogger
|
|
|
|
mu sync.Mutex
|
|
choiceCount uint32
|
|
// endpointRPCCounts holds RPC counts to keep track for subsequent picker
|
|
// updates.
|
|
endpointRPCCounts *resolver.EndpointMap // endpoint -> *atomic.Int32
|
|
}
|
|
|
|
func (lrb *leastRequestBalancer) Close() {
|
|
lrb.child.Close()
|
|
lrb.endpointRPCCounts = nil
|
|
}
|
|
|
|
func (lrb *leastRequestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
|
|
lrCfg, ok := ccs.BalancerConfig.(*LBConfig)
|
|
if !ok {
|
|
logger.Errorf("least-request: received config with unexpected type %T: %v", ccs.BalancerConfig, ccs.BalancerConfig)
|
|
return balancer.ErrBadResolverState
|
|
}
|
|
|
|
lrb.mu.Lock()
|
|
lrb.choiceCount = lrCfg.ChoiceCount
|
|
lrb.mu.Unlock()
|
|
// Enable the health listener in pickfirst children for client side health
|
|
// checks and outlier detection, if configured.
|
|
ccs.ResolverState = pickfirstleaf.EnableHealthListener(ccs.ResolverState)
|
|
ccs.BalancerConfig = endpointShardingLBConfig
|
|
return lrb.child.UpdateClientConnState(ccs)
|
|
}
|
|
|
|
type endpointState struct {
|
|
picker balancer.Picker
|
|
numRPCs *atomic.Int32
|
|
}
|
|
|
|
func (lrb *leastRequestBalancer) UpdateState(state balancer.State) {
|
|
var readyEndpoints []endpointsharding.ChildState
|
|
for _, child := range endpointsharding.ChildStatesFromPicker(state.Picker) {
|
|
if child.State.ConnectivityState == connectivity.Ready {
|
|
readyEndpoints = append(readyEndpoints, child)
|
|
}
|
|
}
|
|
|
|
// If no ready pickers are present, simply defer to the round robin picker
|
|
// from endpoint sharding, which will round robin across the most relevant
|
|
// pick first children in the highest precedence connectivity state.
|
|
if len(readyEndpoints) == 0 {
|
|
lrb.ClientConn.UpdateState(state)
|
|
return
|
|
}
|
|
|
|
lrb.mu.Lock()
|
|
defer lrb.mu.Unlock()
|
|
|
|
if logger.V(2) {
|
|
lrb.logger.Infof("UpdateState called with ready endpoints: %v", readyEndpoints)
|
|
}
|
|
|
|
// Reconcile endpoints.
|
|
newEndpoints := resolver.NewEndpointMap() // endpoint -> nil
|
|
for _, child := range readyEndpoints {
|
|
newEndpoints.Set(child.Endpoint, nil)
|
|
}
|
|
|
|
// If endpoints are no longer ready, no need to count their active RPCs.
|
|
for _, endpoint := range lrb.endpointRPCCounts.Keys() {
|
|
if _, ok := newEndpoints.Get(endpoint); !ok {
|
|
lrb.endpointRPCCounts.Delete(endpoint)
|
|
}
|
|
}
|
|
|
|
// Copy refs to counters into picker.
|
|
endpointStates := make([]endpointState, 0, len(readyEndpoints))
|
|
for _, child := range readyEndpoints {
|
|
var counter *atomic.Int32
|
|
if val, ok := lrb.endpointRPCCounts.Get(child.Endpoint); !ok {
|
|
// Create new counts if needed.
|
|
counter = new(atomic.Int32)
|
|
lrb.endpointRPCCounts.Set(child.Endpoint, counter)
|
|
} else {
|
|
counter = val.(*atomic.Int32)
|
|
}
|
|
endpointStates = append(endpointStates, endpointState{
|
|
picker: child.State.Picker,
|
|
numRPCs: counter,
|
|
})
|
|
}
|
|
|
|
lrb.ClientConn.UpdateState(balancer.State{
|
|
Picker: &picker{
|
|
choiceCount: lrb.choiceCount,
|
|
endpointStates: endpointStates,
|
|
},
|
|
ConnectivityState: connectivity.Ready,
|
|
})
|
|
}
|
|
|
|
type picker struct {
|
|
// choiceCount is the number of random endpoints to sample for choosing the
|
|
// one with the least requests.
|
|
choiceCount uint32
|
|
endpointStates []endpointState
|
|
}
|
|
|
|
func (p *picker) Pick(pInfo balancer.PickInfo) (balancer.PickResult, error) {
|
|
var pickedEndpointState *endpointState
|
|
var pickedEndpointNumRPCs int32
|
|
for i := 0; i < int(p.choiceCount); i++ {
|
|
index := randuint32() % uint32(len(p.endpointStates))
|
|
endpointState := p.endpointStates[index]
|
|
n := endpointState.numRPCs.Load()
|
|
if pickedEndpointState == nil || n < pickedEndpointNumRPCs {
|
|
pickedEndpointState = &endpointState
|
|
pickedEndpointNumRPCs = n
|
|
}
|
|
}
|
|
result, err := pickedEndpointState.picker.Pick(pInfo)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
// "The counter for a subchannel should be atomically incremented by one
|
|
// after it has been successfully picked by the picker." - A48
|
|
pickedEndpointState.numRPCs.Add(1)
|
|
// "the picker should add a callback for atomically decrementing the
|
|
// subchannel counter once the RPC finishes (regardless of Status code)." -
|
|
// A48.
|
|
originalDone := result.Done
|
|
result.Done = func(info balancer.DoneInfo) {
|
|
pickedEndpointState.numRPCs.Add(-1)
|
|
if originalDone != nil {
|
|
originalDone(info)
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|