Cherry-pick #8159 and #8243 to v1.72.x (#8255)

This commit is contained in:
Purnesh Dixit 2025-04-16 15:39:31 +05:30 committed by GitHub
parent 79ca1744ed
commit fd6f585291
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 1155 additions and 190 deletions

View File

@ -55,6 +55,20 @@ var (
// setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST"
// to "false".
NewPickFirstEnabled = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", true)
// XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash
// key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by
// setting "GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT" to "true". When the
// implementation of A76 is stable, we will flip the default value to false
// in a subsequent release. A final release will remove this environment
// variable, enabling the new behavior unconditionally.
XDSEndpointHashKeyBackwardCompat = boolFromEnv("GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT", true)
// RingHashSetRequestHashKey is set if the ring hash balancer can get the
// request hash header by setting the "requestHashHeader" field, according
// to gRFC A76. It can be enabled by setting the environment variable
// "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY" to "true".
RingHashSetRequestHashKey = boolFromEnv("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY", false)
)
func boolFromEnv(envVar string, def bool) bool {

View File

@ -97,13 +97,11 @@ func hasNotPrintable(msg string) bool {
return false
}
// ValidatePair validate a key-value pair with the following rules (the pseudo-header will be skipped) :
//
// - key must contain one or more characters.
// - the characters in the key must be contained in [0-9 a-z _ - .].
// - if the key ends with a "-bin" suffix, no validation of the corresponding value is performed.
// - the characters in the every value must be printable (in [%x20-%x7E]).
func ValidatePair(key string, vals ...string) error {
// ValidateKey validates a key with the following rules (pseudo-headers are
// skipped):
// - the key must contain one or more characters.
// - the characters in the key must be in [0-9 a-z _ - .].
func ValidateKey(key string) error {
// key should not be empty
if key == "" {
return fmt.Errorf("there is an empty key in the header")
@ -119,6 +117,20 @@ func ValidatePair(key string, vals ...string) error {
return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", key)
}
}
return nil
}
// ValidatePair validates a key-value pair with the following rules
// (pseudo-header are skipped):
// - the key must contain one or more characters.
// - the characters in the key must be in [0-9 a-z _ - .].
// - if the key ends with a "-bin" suffix, no validation of the corresponding
// value is performed.
// - the characters in every value must be printable (in [%x20-%x7E]).
func ValidatePair(key string, vals ...string) error {
if err := ValidateKey(key); err != nil {
return err
}
if strings.HasSuffix(key, "-bin") {
return nil
}

View File

@ -0,0 +1,33 @@
package testutils
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import (
"testing"
)
// SetEnvConfig sets the value of the given variable to the specified value,
// taking care of restoring the original value after the test completes.
func SetEnvConfig[T any](t *testing.T, variable *T, value T) {
t.Helper()
old := *variable
t.Cleanup(func() {
*variable = old
})
*variable = value
}

View File

@ -26,6 +26,7 @@ import (
"github.com/envoyproxy/go-control-plane/pkg/wellknown"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/wrapperspb"
v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
@ -649,6 +650,9 @@ type BackendOptions struct {
HealthStatus v3corepb.HealthStatus
// Weight sets the backend weight. Defaults to 1.
Weight uint32
// Metadata sets the LB endpoint metadata (envoy.lb FilterMetadata field).
// See https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/core/v3/base.proto#envoy-v3-api-msg-config-core-v3-metadata
Metadata map[string]any
}
// EndpointOptions contains options to configure an Endpoint (or
@ -708,6 +712,10 @@ func EndpointResourceWithOptions(opts EndpointOptions) *v3endpointpb.ClusterLoad
},
}
}
metadata, err := structpb.NewStruct(b.Metadata)
if err != nil {
panic(err)
}
lbEndpoints = append(lbEndpoints, &v3endpointpb.LbEndpoint{
HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{Endpoint: &v3endpointpb.Endpoint{
Address: &v3corepb.Address{Address: &v3corepb.Address_SocketAddress{
@ -721,6 +729,11 @@ func EndpointResourceWithOptions(opts EndpointOptions) *v3endpointpb.ClusterLoad
}},
HealthStatus: b.HealthStatus,
LoadBalancingWeight: &wrapperspb.UInt32Value{Value: b.Weight},
Metadata: &v3corepb.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"envoy.lb": metadata,
},
},
})
}

60
resolver/ringhash/attr.go Normal file
View File

@ -0,0 +1,60 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package ringhash implements resolver related functions for the ring_hash
// load balancing policy.
package ringhash
import (
"google.golang.org/grpc/resolver"
)
type hashKeyType string
// hashKeyKey is the key to store the ring hash key attribute in
// a resolver.Endpoint attribute.
const hashKeyKey = hashKeyType("grpc.resolver.ringhash.hash_key")
// SetHashKey sets the hash key for this endpoint. Combined with the ring_hash
// load balancing policy, it allows placing the endpoint on the ring based on an
// arbitrary string instead of the IP address. If hashKey is empty, the endpoint
// is returned unmodified.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetHashKey(endpoint resolver.Endpoint, hashKey string) resolver.Endpoint {
if hashKey == "" {
return endpoint
}
endpoint.Attributes = endpoint.Attributes.WithValue(hashKeyKey, hashKey)
return endpoint
}
// HashKey returns the hash key attribute of endpoint. If this attribute is
// not set, it returns the empty string.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func HashKey(endpoint resolver.Endpoint) string {
hashKey, _ := endpoint.Attributes.Value(hashKeyKey).(string)
return hashKey
}

View File

@ -27,6 +27,7 @@ import (
"google.golang.org/grpc/internal/hierarchy"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/ringhash"
"google.golang.org/grpc/xds/internal"
"google.golang.org/grpc/xds/internal/balancer/clusterimpl"
"google.golang.org/grpc/xds/internal/balancer/outlierdetection"
@ -284,6 +285,7 @@ func priorityLocalitiesToClusterImpl(localities []xdsresource.Locality, priority
ew = endpoint.Weight
}
resolverEndpoint = weight.Set(resolverEndpoint, weight.EndpointInfo{Weight: lw * ew})
resolverEndpoint = ringhash.SetHashKey(resolverEndpoint, endpoint.HashKey)
retEndpoints = append(retEndpoints, resolverEndpoint)
}
}

View File

@ -21,8 +21,10 @@ package ringhash
import (
"encoding/json"
"fmt"
"strings"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/serviceconfig"
)
@ -30,8 +32,9 @@ import (
type LBConfig struct {
serviceconfig.LoadBalancingConfig `json:"-"`
MinRingSize uint64 `json:"minRingSize,omitempty"`
MaxRingSize uint64 `json:"maxRingSize,omitempty"`
MinRingSize uint64 `json:"minRingSize,omitempty"`
MaxRingSize uint64 `json:"maxRingSize,omitempty"`
RequestHashHeader string `json:"requestHashHeader,omitempty"`
}
const (
@ -66,5 +69,18 @@ func parseConfig(c json.RawMessage) (*LBConfig, error) {
if cfg.MaxRingSize > envconfig.RingHashCap {
cfg.MaxRingSize = envconfig.RingHashCap
}
if !envconfig.RingHashSetRequestHashKey {
cfg.RequestHashHeader = ""
}
if cfg.RequestHashHeader != "" {
cfg.RequestHashHeader = strings.ToLower(cfg.RequestHashHeader)
// See rules in https://github.com/grpc/proposal/blob/master/A76-ring-hash-improvements.md#explicitly-setting-the-request-hash-key
if err := metadata.ValidateKey(cfg.RequestHashHeader); err != nil {
return nil, fmt.Errorf("invalid requestHashHeader %q: %v", cfg.RequestHashHeader, err)
}
if strings.HasSuffix(cfg.RequestHashHeader, "-bin") {
return nil, fmt.Errorf("invalid requestHashHeader %q: key must not end with \"-bin\"", cfg.RequestHashHeader)
}
}
return &cfg, nil
}

View File

@ -19,90 +19,146 @@
package ringhash
import (
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/testutils"
)
func (s) TestParseConfig(t *testing.T) {
tests := []struct {
name string
js string
envConfigCap uint64
want *LBConfig
wantErr bool
name string
js string
envConfigCap uint64
requestHeaderEnvVar bool
want *LBConfig
wantErr bool
}{
{
name: "OK",
js: `{"minRingSize": 1, "maxRingSize": 2}`,
want: &LBConfig{MinRingSize: 1, MaxRingSize: 2},
name: "OK",
js: `{"minRingSize": 1, "maxRingSize": 2}`,
requestHeaderEnvVar: true,
want: &LBConfig{MinRingSize: 1, MaxRingSize: 2},
},
{
name: "OK with default min",
js: `{"maxRingSize": 2000}`,
want: &LBConfig{MinRingSize: defaultMinSize, MaxRingSize: 2000},
name: "OK with default min",
js: `{"maxRingSize": 2000}`,
requestHeaderEnvVar: true,
want: &LBConfig{MinRingSize: defaultMinSize, MaxRingSize: 2000},
},
{
name: "OK with default max",
js: `{"minRingSize": 2000}`,
want: &LBConfig{MinRingSize: 2000, MaxRingSize: defaultMaxSize},
name: "OK with default max",
js: `{"minRingSize": 2000}`,
requestHeaderEnvVar: true,
want: &LBConfig{MinRingSize: 2000, MaxRingSize: defaultMaxSize},
},
{
name: "min greater than max",
js: `{"minRingSize": 10, "maxRingSize": 2}`,
want: nil,
wantErr: true,
name: "min greater than max",
js: `{"minRingSize": 10, "maxRingSize": 2}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "min greater than max greater than global limit",
js: `{"minRingSize": 6000, "maxRingSize": 5000}`,
want: nil,
wantErr: true,
name: "min greater than max greater than global limit",
js: `{"minRingSize": 6000, "maxRingSize": 5000}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "max greater than global limit",
js: `{"minRingSize": 1, "maxRingSize": 6000}`,
want: &LBConfig{MinRingSize: 1, MaxRingSize: 4096},
name: "max greater than global limit",
js: `{"minRingSize": 1, "maxRingSize": 6000}`,
requestHeaderEnvVar: true,
want: &LBConfig{MinRingSize: 1, MaxRingSize: 4096},
},
{
name: "min and max greater than global limit",
js: `{"minRingSize": 5000, "maxRingSize": 6000}`,
want: &LBConfig{MinRingSize: 4096, MaxRingSize: 4096},
name: "min and max greater than global limit",
js: `{"minRingSize": 5000, "maxRingSize": 6000}`,
requestHeaderEnvVar: true,
want: &LBConfig{MinRingSize: 4096, MaxRingSize: 4096},
},
{
name: "min and max less than raised global limit",
js: `{"minRingSize": 5000, "maxRingSize": 6000}`,
envConfigCap: 8000,
want: &LBConfig{MinRingSize: 5000, MaxRingSize: 6000},
name: "min and max less than raised global limit",
js: `{"minRingSize": 5000, "maxRingSize": 6000}`,
envConfigCap: 8000,
requestHeaderEnvVar: true,
want: &LBConfig{MinRingSize: 5000, MaxRingSize: 6000},
},
{
name: "min and max greater than raised global limit",
js: `{"minRingSize": 10000, "maxRingSize": 10000}`,
envConfigCap: 8000,
want: &LBConfig{MinRingSize: 8000, MaxRingSize: 8000},
name: "min and max greater than raised global limit",
js: `{"minRingSize": 10000, "maxRingSize": 10000}`,
envConfigCap: 8000,
requestHeaderEnvVar: true,
want: &LBConfig{MinRingSize: 8000, MaxRingSize: 8000},
},
{
name: "min greater than upper bound",
js: `{"minRingSize": 8388610, "maxRingSize": 10}`,
want: nil,
wantErr: true,
name: "min greater than upper bound",
js: `{"minRingSize": 8388610, "maxRingSize": 10}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "max greater than upper bound",
js: `{"minRingSize": 10, "maxRingSize": 8388610}`,
want: nil,
wantErr: true,
name: "max greater than upper bound",
js: `{"minRingSize": 10, "maxRingSize": 8388610}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "request metadata key set",
js: `{"requestHashHeader": "x-foo"}`,
requestHeaderEnvVar: true,
want: &LBConfig{
MinRingSize: defaultMinSize,
MaxRingSize: defaultMaxSize,
RequestHashHeader: "x-foo",
},
},
{
name: "request metadata key set with uppercase letters",
js: `{"requestHashHeader": "x-FOO"}`,
requestHeaderEnvVar: true,
want: &LBConfig{
MinRingSize: defaultMinSize,
MaxRingSize: defaultMaxSize,
RequestHashHeader: "x-foo",
},
},
{
name: "invalid request hash header",
js: `{"requestHashHeader": "!invalid"}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "binary request hash header",
js: `{"requestHashHeader": "header-with-bin"}`,
requestHeaderEnvVar: true,
want: nil,
wantErr: true,
},
{
name: "request hash header cleared when RingHashSetRequestHashKey env var is false",
js: `{"requestHashHeader": "x-foo"}`,
requestHeaderEnvVar: false,
want: &LBConfig{
MinRingSize: defaultMinSize,
MaxRingSize: defaultMaxSize,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envConfigCap != 0 {
old := envconfig.RingHashCap
defer func() { envconfig.RingHashCap = old }()
envconfig.RingHashCap = tt.envConfigCap
testutils.SetEnvConfig(t, &envconfig.RingHashCap, tt.envConfigCap)
}
got, err := parseConfig([]byte(tt.js))
testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, tt.requestHeaderEnvVar)
got, err := parseConfig(json.RawMessage(tt.js))
if (err != nil) != tt.wantErr {
t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@ -26,6 +26,8 @@ import (
rand "math/rand/v2"
"net"
"slices"
"strconv"
"sync"
"testing"
"time"
@ -48,6 +50,7 @@ import (
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/status"
"google.golang.org/grpc/xds/internal/balancer/ringhash"
v3clusterpb "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@ -123,9 +126,10 @@ func (s) TestRingHash_ReconnectToMoveOutOfTransientFailure(t *testing.T) {
defer cc.Close()
// Push the address of the test backend through the manual resolver.
r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
ctx = ringhash.SetXDSRequestHash(ctx, 0)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
@ -469,7 +473,7 @@ func (s) TestRingHash_AggregateClusterFallBackFromRingHashToLogicalDnsAtStartup(
}
dnsR := replaceDNSResolver(t)
dnsR.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0]}}})
dnsR.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0]}}})
if err := xdsServer.Update(ctx, updateOpts); err != nil {
t.Fatalf("Failed to update xDS resources: %v", err)
@ -547,7 +551,7 @@ func (s) TestRingHash_AggregateClusterFallBackFromRingHashToLogicalDnsAtStartupN
}
dnsR := replaceDNSResolver(t)
dnsR.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0]}}})
dnsR.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: backends[0]}}})
if err := xdsServer.Update(ctx, updateOpts); err != nil {
t.Fatalf("Failed to update xDS resources: %v", err)
@ -2542,3 +2546,373 @@ func (s) TestRingHash_RecoverWhenResolverRemovesEndpoint(t *testing.T) {
// Wait for channel to become READY without any pending RPC.
testutils.AwaitState(ctx, t, conn, connectivity.Ready)
}
// Tests that RPCs are routed according to endpoint hash key rather than
// endpoint first address if it is set in EDS endpoint metadata.
func (s) TestRingHash_EndpointHashKey(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.XDSEndpointHashKeyBackwardCompat, false)
backends := backendAddrs(startTestServiceBackends(t, 4))
const clusterName = "cluster"
var backendOpts []e2e.BackendOptions
for i, addr := range backends {
var ports []uint32
ports = append(ports, testutils.ParsePort(t, addr))
backendOpts = append(backendOpts, e2e.BackendOptions{
Ports: ports,
Metadata: map[string]any{"hash_key": strconv.Itoa(i)},
})
}
endpoints := e2e.EndpointResourceWithOptions(e2e.EndpointOptions{
ClusterName: clusterName,
Host: "localhost",
Localities: []e2e.LocalityOptions{{
Backends: backendOpts,
Weight: 1,
}},
})
cluster := e2e.ClusterResourceWithOptions(e2e.ClusterOptions{
ClusterName: clusterName,
ServiceName: clusterName,
Policy: e2e.LoadBalancingPolicyRingHash,
})
route := headerHashRoute("new_route", virtualHostName, clusterName, "address_hash")
listener := e2e.DefaultClientListener(virtualHostName, route.Name)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
xdsServer, nodeID, xdsResolver := setupManagementServerAndResolver(t)
if err := xdsServer.Update(ctx, xdsUpdateOpts(nodeID, endpoints, cluster, route, listener)); err != nil {
t.Fatalf("Failed to update xDS resources: %v", err)
}
opts := []grpc.DialOption{
grpc.WithResolvers(xdsResolver),
grpc.WithTransportCredentials(insecure.NewCredentials()),
}
conn, err := grpc.NewClient("xds:///test.server", opts...)
if err != nil {
t.Fatalf("Failed to create client: %s", err)
}
defer conn.Close()
client := testgrpc.NewTestServiceClient(conn)
// Make sure RPCs are routed to backends according to the endpoint metadata
// rather than their address. Note each type of RPC contains a header value
// that will always be hashed to a specific backend as the header value
// matches the endpoint metadata hash key.
for i, backend := range backends {
ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("address_hash", strconv.Itoa(i)+"_0"))
numRPCs := 10
reqPerBackend := checkRPCSendOK(ctx, t, client, numRPCs)
if reqPerBackend[backend] != numRPCs {
t.Errorf("Got RPC routed to addresses %v, want all RPCs routed to %v", reqPerBackend, backend)
}
}
// Update the endpoints to swap the metadata hash key.
for i := range backendOpts {
backendOpts[i].Metadata = map[string]any{"hash_key": strconv.Itoa(len(backends) - i - 1)}
}
endpoints = e2e.EndpointResourceWithOptions(e2e.EndpointOptions{
ClusterName: clusterName,
Host: "localhost",
Localities: []e2e.LocalityOptions{{
Backends: backendOpts,
Weight: 1,
}},
})
if err := xdsServer.Update(ctx, xdsUpdateOpts(nodeID, endpoints, cluster, route, listener)); err != nil {
t.Fatalf("Failed to update xDS resources: %v", err)
}
// Wait for the resolver update to make it to the balancer. This RPC should
// be routed to backend 3 with the reverse numbering of the hash_key
// attribute delivered above.
for {
ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("address_hash", "0_0"))
var remote peer.Peer
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&remote)); err != nil {
t.Fatalf("Unexpected RPC error waiting for EDS update propagation: %s", err)
}
if remote.Addr.String() == backends[3] {
break
}
}
// Now that the balancer has the new endpoint attributes, make sure RPCs are
// routed to backends according to the new endpoint metadata.
for i, backend := range backends {
ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("address_hash", strconv.Itoa(len(backends)-i-1)+"_0"))
numRPCs := 10
reqPerBackend := checkRPCSendOK(ctx, t, client, numRPCs)
if reqPerBackend[backend] != numRPCs {
t.Errorf("Got RPC routed to addresses %v, want all RPCs routed to %v", reqPerBackend, backend)
}
}
}
// Tests that when a request hash key is set in the balancer configuration via
// service config, this header is used to route to a specific backend.
func (s) TestRingHash_RequestHashKey(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, true)
backends := backendAddrs(startTestServiceBackends(t, 4))
// Create a clientConn with a manual resolver (which is used to push the
// address of the test backend), and a default service config pointing to
// the use of the ring_hash_experimental LB policy with an explicit hash
// header.
const ringHashServiceConfig = `{"loadBalancingConfig": [{"ring_hash_experimental":{"requestHashHeader":"address_hash"}}]}`
r := manual.NewBuilderWithScheme("whatever")
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(ringHashServiceConfig),
grpc.WithConnectParams(fastConnectParams),
}
cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("Failed to dial local test server: %v", err)
}
defer cc.Close()
var endpoints []resolver.Endpoint
for _, backend := range backends {
endpoints = append(endpoints, resolver.Endpoint{
Addresses: []resolver.Address{{Addr: backend}},
})
}
r.UpdateState(resolver.State{
Endpoints: endpoints,
})
client := testgrpc.NewTestServiceClient(cc)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Note each type of RPC contains a header value that will always be hashed
// to a specific backend as the header value matches the value used to
// create the entry in the ring.
for _, backend := range backends {
ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("address_hash", backend+"_0"))
numRPCs := 10
reqPerBackend := checkRPCSendOK(ctx, t, client, numRPCs)
if reqPerBackend[backend] != numRPCs {
t.Errorf("Got RPC routed to addresses %v, want all RPCs routed to %v", reqPerBackend, backend)
}
}
const ringHashServiceConfigUpdate = `{"loadBalancingConfig": [{"ring_hash_experimental":{"requestHashHeader":"other_header"}}]}`
r.UpdateState(resolver.State{
Endpoints: endpoints,
ServiceConfig: (&testutils.ResolverClientConn{}).ParseServiceConfig(ringHashServiceConfigUpdate),
})
// Make sure that requests with the new hash are sent to the right backend.
for _, backend := range backends {
ctx := metadata.NewOutgoingContext(ctx, metadata.Pairs("other_header", backend+"_0"))
numRPCs := 10
reqPerBackend := checkRPCSendOK(ctx, t, client, numRPCs)
if reqPerBackend[backend] != numRPCs {
t.Errorf("Got RPC routed to addresses %v, want all RPCs routed to %v", reqPerBackend, backend)
}
}
}
// Tests that when a request hash key is set in the balancer configuration via
// service config, and the header is not set in the outgoing request, then it
// is sent to a random backend.
func (s) TestRingHash_RequestHashKeyRandom(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, true)
backends := backendAddrs(startTestServiceBackends(t, 4))
// Create a clientConn with a manual resolver (which is used to push the
// address of the test backend), and a default service config pointing to
// the use of the ring_hash_experimental LB policy with an explicit hash
// header.
const ringHashServiceConfig = `{"loadBalancingConfig": [{"ring_hash_experimental":{"requestHashHeader":"address_hash"}}]}`
r := manual.NewBuilderWithScheme("whatever")
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(ringHashServiceConfig),
grpc.WithConnectParams(fastConnectParams),
}
cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("Failed to dial local test server: %v", err)
}
defer cc.Close()
var endpoints []resolver.Endpoint
for _, backend := range backends {
endpoints = append(endpoints, resolver.Endpoint{
Addresses: []resolver.Address{{Addr: backend}},
})
}
r.UpdateState(resolver.State{
Endpoints: endpoints,
})
client := testgrpc.NewTestServiceClient(cc)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Due to the way that ring hash lazily establishes connections when using a
// random hash, request distribution is skewed towards the order in which we
// connected. The test send RPCs until we are connected to all backends, so
// we can later assert that the distribution is uniform.
seen := make(map[string]bool)
for len(seen) != 4 {
var remote peer.Peer
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&remote)); err != nil {
t.Fatalf("rpc EmptyCall() failed: %v", err)
}
seen[remote.String()] = true
}
// Make sure that requests with the old hash are sent to random backends.
numRPCs := computeIdealNumberOfRPCs(t, .25, errorTolerance)
gotPerBackend := checkRPCSendOK(ctx, t, client, numRPCs)
for _, backend := range backends {
got := float64(gotPerBackend[backend]) / float64(numRPCs)
want := .25
if !cmp.Equal(got, want, cmpopts.EquateApprox(0, errorTolerance)) {
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backend, got, want, errorTolerance)
}
}
}
// Tests that when a request hash key is set in the balancer configuration via
// service config, and the header is not set in the outgoing request (random
// behavior), then each RPC wakes up at most one SubChannel, and, if there are
// SubChannels in Ready state, RPCs are routed to them.
func (s) TestRingHash_RequestHashKeyConnecting(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.RingHashSetRequestHashKey, true)
backends := backendAddrs(startTestServiceBackends(t, 20))
// Create a clientConn with a manual resolver (which is used to push the
// address of the test backend), and a default service config pointing to
// the use of the ring_hash_experimental LB policy with an explicit hash
// header. Use a blocking dialer to control connection attempts.
const ringHashServiceConfig = `{"loadBalancingConfig": [
{"ring_hash_experimental":{"requestHashHeader":"address_hash"}}
]}`
r := manual.NewBuilderWithScheme("whatever")
blockingDialer := testutils.NewBlockingDialer()
dopts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithResolvers(r),
grpc.WithDefaultServiceConfig(ringHashServiceConfig),
grpc.WithConnectParams(fastConnectParams),
grpc.WithContextDialer(blockingDialer.DialContext),
}
cc, err := grpc.NewClient(r.Scheme()+":///test.server", dopts...)
if err != nil {
t.Fatalf("Failed to dial local test server: %v", err)
}
defer cc.Close()
var endpoints []resolver.Endpoint
for _, backend := range backends {
endpoints = append(endpoints, resolver.Endpoint{
Addresses: []resolver.Address{{Addr: backend}},
})
}
r.UpdateState(resolver.State{
Endpoints: endpoints,
})
client := testgrpc.NewTestServiceClient(cc)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
// Intercept all connection attempts to the backends.
var holds []*testutils.Hold
for i := 0; i < len(backends); i++ {
holds = append(holds, blockingDialer.Hold(backends[i]))
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
// Send 1 RPC and make sure this triggers at most 1 connection attempt.
_, err := client.EmptyCall(ctx, &testpb.Empty{})
if err != nil {
t.Errorf("EmptyCall(): got %v, want success", err)
}
wg.Done()
}()
testutils.AwaitState(ctx, t, cc, connectivity.Connecting)
// Check that only one connection attempt was started.
nConn := 0
for _, hold := range holds {
if hold.IsStarted() {
nConn++
}
}
if wantMaxConn := 1; nConn > wantMaxConn {
t.Fatalf("Got %d connection attempts, want at most %d", nConn, wantMaxConn)
}
// Do a second RPC. Since there should already be a SubChannel in
// Connecting state, this should not trigger a connection attempt.
wg.Add(1)
go func() {
_, err := client.EmptyCall(ctx, &testpb.Empty{})
if err != nil {
t.Errorf("EmptyCall(): got %v, want success", err)
}
wg.Done()
}()
// Give extra time for more connections to be attempted.
time.Sleep(defaultTestShortTimeout)
var firstConnectedBackend string
nConn = 0
for i, hold := range holds {
if hold.IsStarted() {
// Unblock the connection attempt. The SubChannel (and hence the
// channel) should transition to Ready. RPCs should succeed and
// be routed to this backend.
hold.Resume()
holds[i] = nil
firstConnectedBackend = backends[i]
nConn++
}
}
if wantMaxConn := 1; nConn > wantMaxConn {
t.Fatalf("Got %d connection attempts, want at most %d", nConn, wantMaxConn)
}
testutils.AwaitState(ctx, t, cc, connectivity.Ready)
wg.Wait() // Make sure we're done with the 2 previous RPCs.
// Now send RPCs until we have at least one more connection attempt, that
// is, the random hash did not land on the same backend on every pick (the
// chances are low, but we don't want this to be flaky). Make sure no RPC
// fails and that we route all of them to the only subchannel in ready
// state.
nConn = 0
for nConn == 0 {
p := peer.Peer{}
_, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&p))
if status.Code(err) == codes.DeadlineExceeded {
t.Fatal("EmptyCall(): test timed out while waiting for more connection attempts")
}
if err != nil {
t.Fatalf("EmptyCall(): got %v, want success", err)
}
if p.Addr.String() != firstConnectedBackend {
t.Errorf("RPC sent to backend %q, want %q", p.Addr.String(), firstConnectedBackend)
}
for _, hold := range holds {
if hold != nil && hold.IsStarted() {
nConn++
}
}
}
}

View File

@ -20,46 +20,103 @@ package ringhash
import (
"fmt"
"strings"
xxhash "github.com/cespare/xxhash/v2"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/metadata"
)
type picker struct {
ring *ring
logger *grpclog.PrefixLogger
// endpointStates is a cache of endpoint connectivity states and pickers.
ring *ring
// endpointStates is a cache of endpoint states.
// The ringhash balancer stores endpoint states in a `resolver.EndpointMap`,
// with access guarded by `ringhashBalancer.mu`. The `endpointStates` cache
// in the picker helps avoid locking the ringhash balancer's mutex when
// reading the latest state at RPC time.
endpointStates map[string]balancer.State // endpointState.firstAddr -> balancer.State
endpointStates map[string]endpointState // endpointState.hashKey -> endpointState
// requestHashHeader is the header key to look for the request hash. If it's
// empty, the request hash is expected to be set in the context via xDS.
// See gRFC A76.
requestHashHeader string
// hasEndpointInConnectingState is true if any of the endpoints is in
// CONNECTING.
hasEndpointInConnectingState bool
randUint64 func() uint64
}
func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
e := p.ring.pick(getRequestHash(info.Ctx))
ringSize := len(p.ring.items)
// Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF,
// we ignore all TF subchannels and find the first ring entry in READY,
// CONNECTING or IDLE. If that entry is in IDLE, we need to initiate a
// connection. The idlePicker returned by the LazyLB or the new Pickfirst
// should do this automatically.
for i := 0; i < ringSize; i++ {
index := (e.idx + i) % ringSize
balState := p.balancerState(p.ring.items[index])
switch balState.ConnectivityState {
case connectivity.Ready, connectivity.Connecting, connectivity.Idle:
return balState.Picker.Pick(info)
case connectivity.TransientFailure:
default:
panic(fmt.Sprintf("Found child balancer in unknown state: %v", balState.ConnectivityState))
usingRandomHash := false
var requestHash uint64
if p.requestHashHeader == "" {
var ok bool
if requestHash, ok = XDSRequestHash(info.Ctx); !ok {
return balancer.PickResult{}, fmt.Errorf("ringhash: expected xDS config selector to set the request hash")
}
} else {
md, ok := metadata.FromOutgoingContext(info.Ctx)
if !ok || len(md.Get(p.requestHashHeader)) == 0 {
requestHash = p.randUint64()
usingRandomHash = true
} else {
values := strings.Join(md.Get(p.requestHashHeader), ",")
requestHash = xxhash.Sum64String(values)
}
}
e := p.ring.pick(requestHash)
ringSize := len(p.ring.items)
if !usingRandomHash {
// Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF,
// we ignore all TF subchannels and find the first ring entry in READY,
// CONNECTING or IDLE. If that entry is in IDLE, we need to initiate a
// connection. The idlePicker returned by the LazyLB or the new Pickfirst
// should do this automatically.
for i := 0; i < ringSize; i++ {
index := (e.idx + i) % ringSize
es := p.endpointState(p.ring.items[index])
switch es.state.ConnectivityState {
case connectivity.Ready, connectivity.Connecting, connectivity.Idle:
return es.state.Picker.Pick(info)
case connectivity.TransientFailure:
default:
panic(fmt.Sprintf("Found child balancer in unknown state: %v", es.state.ConnectivityState))
}
}
} else {
// If the picker has generated a random hash, it will walk the ring from
// this hash, and pick the first READY endpoint. If no endpoint is
// currently in CONNECTING state, it will trigger a connection attempt
// on at most one endpoint that is in IDLE state along the way. - A76
requestedConnection := p.hasEndpointInConnectingState
for i := 0; i < ringSize; i++ {
index := (e.idx + i) % ringSize
es := p.endpointState(p.ring.items[index])
if es.state.ConnectivityState == connectivity.Ready {
return es.state.Picker.Pick(info)
}
if !requestedConnection && es.state.ConnectivityState == connectivity.Idle {
requestedConnection = true
// If the SubChannel is in idle state, initiate a connection but
// continue to check other pickers to see if there is one in
// ready state.
es.balancer.ExitIdle()
}
}
if requestedConnection {
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
}
// All children are in transient failure. Return the first failure.
return p.balancerState(e).Picker.Pick(info)
return p.endpointState(e).state.Picker.Pick(info)
}
func (p *picker) balancerState(e *ringEntry) balancer.State {
return p.endpointStates[e.firstAddr]
func (p *picker) endpointState(e *ringEntry) endpointState {
return p.endpointStates[e.hashKey]
}

View File

@ -20,18 +20,22 @@ package ringhash
import (
"context"
"errors"
"fmt"
"math"
"testing"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/testutils"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/metadata"
)
var testSubConns []*testutils.TestSubConn
var (
testSubConns []*testutils.TestSubConn
errPicker = errors.New("picker in TransientFailure")
)
func init() {
for i := 0; i < 8; i++ {
@ -60,22 +64,35 @@ func (p *fakeChildPicker) Pick(balancer.PickInfo) (balancer.PickResult, error) {
}
}
func testRingAndEndpointStates(states []connectivity.State) (*ring, map[string]balancer.State) {
type fakeExitIdler struct {
sc *testutils.TestSubConn
}
func (ei *fakeExitIdler) ExitIdle() {
ei.sc.Connect()
}
func testRingAndEndpointStates(states []connectivity.State) (*ring, map[string]endpointState) {
var items []*ringEntry
epStates := map[string]balancer.State{}
epStates := map[string]endpointState{}
for i, st := range states {
testSC := testSubConns[i]
items = append(items, &ringEntry{
idx: i,
hash: uint64((i + 1) * 10),
firstAddr: testSC.String(),
idx: i,
hash: math.MaxUint64 / uint64(len(states)) * uint64(i),
hashKey: testSC.String(),
})
epState := balancer.State{
ConnectivityState: st,
Picker: &fakeChildPicker{
connectivityState: st,
tfError: fmt.Errorf("%d", i),
subConn: testSC,
epState := endpointState{
state: balancer.State{
ConnectivityState: st,
Picker: &fakeChildPicker{
connectivityState: st,
tfError: fmt.Errorf("%d: %w", i, errPicker),
subConn: testSC,
},
},
balancer: &fakeExitIdler{
sc: testSC,
},
}
epStates[testSC.String()] = epState
@ -87,7 +104,6 @@ func (s) TestPickerPickFirstTwo(t *testing.T) {
tests := []struct {
name string
connectivityStates []connectivity.State
hash uint64
wantSC balancer.SubConn
wantErr error
wantSCToConnect balancer.SubConn
@ -95,41 +111,40 @@ func (s) TestPickerPickFirstTwo(t *testing.T) {
{
name: "picked is Ready",
connectivityStates: []connectivity.State{connectivity.Ready, connectivity.Idle},
hash: 5,
wantSC: testSubConns[0],
},
{
name: "picked is connecting, queue",
connectivityStates: []connectivity.State{connectivity.Connecting, connectivity.Idle},
hash: 5,
wantErr: balancer.ErrNoSubConnAvailable,
},
{
name: "picked is Idle, connect and queue",
connectivityStates: []connectivity.State{connectivity.Idle, connectivity.Idle},
hash: 5,
wantErr: balancer.ErrNoSubConnAvailable,
wantSCToConnect: testSubConns[0],
},
{
name: "picked is TransientFailure, next is ready, return",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Ready},
hash: 5,
wantSC: testSubConns[1],
},
{
name: "picked is TransientFailure, next is connecting, queue",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Connecting},
hash: 5,
wantErr: balancer.ErrNoSubConnAvailable,
},
{
name: "picked is TransientFailure, next is Idle, connect and queue",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Idle},
hash: 5,
wantErr: balancer.ErrNoSubConnAvailable,
wantSCToConnect: testSubConns[1],
},
{
name: "all are in TransientFailure, return picked failure",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.TransientFailure},
wantErr: errPicker,
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
@ -138,13 +153,12 @@ func (s) TestPickerPickFirstTwo(t *testing.T) {
ring, epStates := testRingAndEndpointStates(tt.connectivityStates)
p := &picker{
ring: ring,
logger: internalgrpclog.NewPrefixLogger(logger, "test-ringhash-picker"),
endpointStates: epStates,
}
got, err := p.Pick(balancer.PickInfo{
Ctx: SetRequestHash(ctx, tt.hash),
Ctx: SetXDSRequestHash(ctx, 0), // always pick the first endpoint on the ring.
})
if err != tt.wantErr {
if (err != nil || tt.wantErr != nil) && !errors.Is(err, tt.wantErr) {
t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr)
return
}
@ -161,3 +175,136 @@ func (s) TestPickerPickFirstTwo(t *testing.T) {
})
}
}
func (s) TestPickerNoRequestHash(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ring, epStates := testRingAndEndpointStates([]connectivity.State{connectivity.Ready})
p := &picker{
ring: ring,
endpointStates: epStates,
}
if _, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err == nil {
t.Errorf("Pick() should have failed with no request hash")
}
}
func (s) TestPickerRequestHashKey(t *testing.T) {
tests := []struct {
name string
headerValues []string
expectedPick int
}{
{
name: "header not set",
expectedPick: 0, // Random hash set to 0, which is within (MaxUint64 / 3 * 2, 0]
},
{
name: "header empty",
headerValues: []string{""},
expectedPick: 0, // xxhash.Sum64String("value1,value2") is within (MaxUint64 / 3 * 2, 0]
},
{
name: "header set to one value",
headerValues: []string{"some-value"},
expectedPick: 1, // xxhash.Sum64String("some-value") is within (0, MaxUint64 / 3]
},
{
name: "header set to multiple values",
headerValues: []string{"value1", "value2"},
expectedPick: 2, // xxhash.Sum64String("value1,value2") is within (MaxUint64 / 3, MaxUint64 / 3 * 2]
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ring, epStates := testRingAndEndpointStates(
[]connectivity.State{
connectivity.Ready,
connectivity.Ready,
connectivity.Ready,
})
headerName := "some-header"
p := &picker{
ring: ring,
endpointStates: epStates,
requestHashHeader: headerName,
randUint64: func() uint64 { return 0 },
}
for _, v := range tt.headerValues {
ctx = metadata.AppendToOutgoingContext(ctx, headerName, v)
}
if res, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err != nil {
t.Errorf("Pick() failed: %v", err)
} else if res.SubConn != testSubConns[tt.expectedPick] {
t.Errorf("Pick() got = %v, want SubConn: %v", res.SubConn, testSubConns[tt.expectedPick])
}
})
}
}
func (s) TestPickerRandomHash(t *testing.T) {
tests := []struct {
name string
hash uint64
connectivityStates []connectivity.State
wantSC balancer.SubConn
wantErr error
wantSCToConnect balancer.SubConn
hasEndpointInConnectingState bool
}{
{
name: "header not set, picked is Ready",
connectivityStates: []connectivity.State{connectivity.Ready, connectivity.Idle},
wantSC: testSubConns[0],
},
{
name: "header not set, picked is Idle, another is Ready. Connect and pick Ready",
connectivityStates: []connectivity.State{connectivity.Idle, connectivity.Ready},
wantSC: testSubConns[1],
wantSCToConnect: testSubConns[0],
},
{
name: "header not set, picked is Idle, there is at least one Connecting",
connectivityStates: []connectivity.State{connectivity.Connecting, connectivity.Idle},
wantErr: balancer.ErrNoSubConnAvailable,
hasEndpointInConnectingState: true,
},
{
name: "header not set, all Idle or TransientFailure, connect",
connectivityStates: []connectivity.State{connectivity.TransientFailure, connectivity.Idle},
wantErr: balancer.ErrNoSubConnAvailable,
wantSCToConnect: testSubConns[1],
},
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ring, epStates := testRingAndEndpointStates(tt.connectivityStates)
p := &picker{
ring: ring,
endpointStates: epStates,
requestHashHeader: "some-header",
hasEndpointInConnectingState: tt.hasEndpointInConnectingState,
randUint64: func() uint64 { return 0 }, // always return the first endpoint on the ring.
}
if got, err := p.Pick(balancer.PickInfo{Ctx: ctx}); err != tt.wantErr {
t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr)
return
} else if got.SubConn != tt.wantSC {
t.Errorf("Pick() got = %v, want picked SubConn: %v", got, tt.wantSC)
}
if sc := tt.wantSCToConnect; sc != nil {
select {
case <-sc.(*testutils.TestSubConn).ConnectCh:
case <-time.After(defaultTestShortTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", sc)
}
}
})
}
}

View File

@ -33,23 +33,23 @@ type ring struct {
}
type endpointInfo struct {
firstAddr string
hashKey string
scaledWeight float64
originalWeight uint32
}
type ringEntry struct {
idx int
hash uint64
firstAddr string
weight uint32
idx int
hash uint64
hashKey string
weight uint32
}
// newRing creates a ring from the endpoints stored in the EndpointMap. The ring
// size is limited by the passed in max/min.
//
// ring entries will be created for each endpoint, and endpoints with high
// weight (specified by the address) may have multiple entries.
// weight (specified by the endpoint) may have multiple entries.
//
// For example, for endpoints with weights {a:3, b:3, c:4}, a generated ring of
// size 10 could be:
@ -109,8 +109,8 @@ func newRing(endpoints *resolver.EndpointMap[*endpointState], minRingSize, maxRi
// updates.
idx := 0
for currentHashes < targetHashes {
h := xxhash.Sum64String(epInfo.firstAddr + "_" + strconv.Itoa(idx))
items = append(items, &ringEntry{hash: h, firstAddr: epInfo.firstAddr, weight: epInfo.originalWeight})
h := xxhash.Sum64String(epInfo.hashKey + "_" + strconv.Itoa(idx))
items = append(items, &ringEntry{hash: h, hashKey: epInfo.hashKey, weight: epInfo.originalWeight})
idx++
currentHashes++
}
@ -153,7 +153,7 @@ func normalizeWeights(endpoints *resolver.EndpointMap[*endpointState]) ([]endpoi
// non-zero. So, we need not worry about divide by zero error here.
nw := float64(epState.weight) / float64(weightSum)
ret = append(ret, endpointInfo{
firstAddr: epState.firstAddr,
hashKey: epState.hashKey,
scaledWeight: nw,
originalWeight: epState.weight,
})
@ -166,7 +166,7 @@ func normalizeWeights(endpoints *resolver.EndpointMap[*endpointState]) ([]endpoi
// where an endpoint is added and then removed, the RPCs will still pick the
// same old endpoint.
sort.Slice(ret, func(i, j int) bool {
return ret[i].firstAddr < ret[j].firstAddr
return ret[i].hashKey < ret[j].hashKey
})
return ret, min
}

View File

@ -39,9 +39,9 @@ func init() {
testEndpoint("c", 4),
}
testEndpointStateMap = resolver.NewEndpointMap[*endpointState]()
testEndpointStateMap.Set(testEndpoints[0], &endpointState{firstAddr: "a", weight: 3})
testEndpointStateMap.Set(testEndpoints[1], &endpointState{firstAddr: "b", weight: 3})
testEndpointStateMap.Set(testEndpoints[2], &endpointState{firstAddr: "c", weight: 4})
testEndpointStateMap.Set(testEndpoints[0], &endpointState{hashKey: "a", weight: 3})
testEndpointStateMap.Set(testEndpoints[1], &endpointState{hashKey: "b", weight: 3})
testEndpointStateMap.Set(testEndpoints[2], &endpointState{hashKey: "c", weight: 4})
}
func testEndpoint(addr string, endpointWeight uint32) resolver.Endpoint {
@ -62,7 +62,7 @@ func (s) TestRingNew(t *testing.T) {
for _, e := range testEndpoints {
var count int
for _, ii := range r.items {
if ii.firstAddr == e.Addresses[0].Addr {
if ii.hashKey == hashKey(e) {
count++
}
}

View File

@ -23,6 +23,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand/v2"
"sort"
"sync"
@ -36,6 +37,7 @@ import (
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/ringhash"
"google.golang.org/grpc/serviceconfig"
)
@ -94,6 +96,18 @@ type ringhashBalancer struct {
ring *ring
}
// hashKey returns the hash key to use for an endpoint. Per gRFC A61, each entry
// in the ring is a hash of the endpoint's hash key concatenated with a
// per-entry unique suffix.
func hashKey(endpoint resolver.Endpoint) string {
if hk := ringhash.HashKey(endpoint); hk != "" {
return hk
}
// If no hash key is set, use the endpoint's first address as the hash key.
// This is the default behavior when no hash key is set.
return endpoint.Addresses[0].Addr
}
// UpdateState intercepts child balancer state updates. It updates the
// per-endpoint state stored in the ring, and also the aggregated state based on
// the child picker. It also reconciles the endpoint list. It sets
@ -114,31 +128,29 @@ func (b *ringhashBalancer) UpdateState(state balancer.State) {
endpoint := childState.Endpoint
endpointsSet.Set(endpoint, true)
newWeight := getWeightAttribute(endpoint)
hk := hashKey(endpoint)
es, ok := b.endpointStates.Get(endpoint)
if !ok {
es = &endpointState{
balancer: childState.Balancer,
weight: newWeight,
firstAddr: endpoint.Addresses[0].Addr,
state: childState.State,
es := &endpointState{
balancer: childState.Balancer,
hashKey: hk,
weight: newWeight,
state: childState.State,
}
b.endpointStates.Set(endpoint, es)
b.shouldRegenerateRing = true
} else {
// We have seen this endpoint before and created a `endpointState`
// object for it. If the weight or the first address of the endpoint
// has changed, update the endpoint state map with the new weight.
// This will be used when a new ring is created.
// object for it. If the weight or the hash key of the endpoint has
// changed, update the endpoint state map with the new weight or
// hash key. This will be used when a new ring is created.
if oldWeight := es.weight; oldWeight != newWeight {
b.shouldRegenerateRing = true
es.weight = newWeight
}
if es.firstAddr != endpoint.Addresses[0].Addr {
// If the order of the addresses for a given endpoint change,
// that will change the position of the endpoint in the ring.
// -A61
if es.hashKey != hk {
b.shouldRegenerateRing = true
es.firstAddr = endpoint.Addresses[0].Addr
es.hashKey = hk
}
es.state = childState.State
}
@ -244,7 +256,7 @@ func (b *ringhashBalancer) updatePickerLocked() {
endpointStates[i] = s
}
sort.Slice(endpointStates, func(i, j int) bool {
return endpointStates[i].firstAddr < endpointStates[j].firstAddr
return endpointStates[i].hashKey < endpointStates[j].hashKey
})
var idleBalancer balancer.ExitIdler
for _, es := range endpointStates {
@ -278,7 +290,6 @@ func (b *ringhashBalancer) updatePickerLocked() {
} else {
newPicker = b.newPickerLocked()
}
b.logger.Infof("Pushing new state %v and picker %p", state, newPicker)
b.ClientConn.UpdateState(balancer.State{
ConnectivityState: state,
Picker: newPicker,
@ -299,11 +310,23 @@ func (b *ringhashBalancer) ExitIdle() {
// over to avoid locking the mutex at RPC time. The picker should be
// re-generated every time an endpoint state is updated.
func (b *ringhashBalancer) newPickerLocked() *picker {
states := make(map[string]balancer.State)
states := make(map[string]endpointState)
hasEndpointConnecting := false
for _, epState := range b.endpointStates.Values() {
states[epState.firstAddr] = epState.state
// Copy the endpoint state to avoid races, since ring hash
// mutates the state, weight and hash key in place.
states[epState.hashKey] = *epState
if epState.state.ConnectivityState == connectivity.Connecting {
hasEndpointConnecting = true
}
}
return &picker{
ring: b.ring,
endpointStates: states,
requestHashHeader: b.config.RequestHashHeader,
hasEndpointInConnectingState: hasEndpointConnecting,
randUint64: rand.Uint64,
}
return &picker{ring: b.ring, logger: b.logger, endpointStates: states}
}
// aggregatedStateLocked returns the aggregated child balancers state
@ -346,8 +369,7 @@ func (b *ringhashBalancer) aggregatedStateLocked() connectivity.State {
}
// getWeightAttribute is a convenience function which returns the value of the
// weight attribute stored in the BalancerAttributes field of addr, using the
// weightedroundrobin package.
// weight endpoint Attribute.
//
// When used in the xDS context, the weight attribute is guaranteed to be
// non-zero. But, when used in a non-xDS context, the weight attribute could be
@ -361,12 +383,13 @@ func getWeightAttribute(e resolver.Endpoint) uint32 {
}
type endpointState struct {
// firstAddr is the first address in the endpoint. Per gRFC A61, each entry
// in the ring is an endpoint, positioned based on the hash of the
// endpoint's first address.
firstAddr string
weight uint32
balancer balancer.ExitIdler
// hashKey is the hash key of the endpoint. Per gRFC A61, each entry in the
// ring is an endpoint, positioned based on the hash of the endpoint's first
// address by default. Per gRFC A76, the hash key of an endpoint may be
// overridden, for example based on EDS endpoint metadata.
hashKey string
weight uint32
balancer balancer.ExitIdler
// state is updated by the balancer while receiving resolver updates from
// the channel and picker updates from its children. Access to it is guarded

View File

@ -83,7 +83,7 @@ func setupTest(t *testing.T, endpoints []resolver.Endpoint) (*testutils.Balancer
t.Errorf("Number of child balancers = %d, want = %d", got, want)
}
for firstAddr, bs := range ringHashPicker.endpointStates {
if got, want := bs.ConnectivityState, connectivity.Idle; got != want {
if got, want := bs.state.ConnectivityState, connectivity.Idle; got != want {
t.Errorf("Child balancer connectivity state for address %q = %v, want = %v", firstAddr, got, want)
}
}
@ -144,7 +144,7 @@ func (s) TestOneEndpoint(t *testing.T) {
// only Endpoint which has a single address.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
var sc0 *testutils.TestSubConn
@ -172,7 +172,7 @@ func (s) TestOneEndpoint(t *testing.T) {
// Test pick with one backend.
p1 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != sc0 {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
}
@ -205,7 +205,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
// SubConn.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
@ -216,7 +216,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
t.Fatalf("Timed out waiting for SubConn creation.")
case subConns[1] = <-cc.NewSubConnCh:
}
if got, want := subConns[1].Addresses[0].Addr, ring.items[1].firstAddr; got != want {
if got, want := subConns[1].Addresses[0].Addr, ring.items[1].hashKey; got != want {
t.Fatalf("SubConn.Address = %v, want = %v", got, want)
}
select {
@ -224,7 +224,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", subConns[1])
}
delete(remainingAddrs, ring.items[1].firstAddr)
delete(remainingAddrs, ring.items[1].hashKey)
// Turn down the subConn in use.
subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
@ -248,9 +248,9 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", subConns[1])
}
if scAddr == ring.items[0].firstAddr {
if scAddr == ring.items[0].hashKey {
subConns[0] = sc
} else if scAddr == ring.items[2].firstAddr {
} else if scAddr == ring.items[2].hashKey {
subConns[2] = sc
}
@ -273,9 +273,9 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
case <-time.After(defaultTestTimeout):
t.Errorf("timeout waiting for Connect() from SubConn %v", subConns[1])
}
if scAddr == ring.items[0].firstAddr {
if scAddr == ring.items[0].hashKey {
subConns[0] = sc
} else if scAddr == ring.items[2].firstAddr {
} else if scAddr == ring.items[2].hashKey {
subConns[2] = sc
}
sc.UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Connecting})
@ -292,7 +292,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
}
p1 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != subConns[0] {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[0])
}
@ -305,7 +305,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
subConns[2].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
p2 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != subConns[2] {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[2])
}
@ -318,7 +318,7 @@ func (s) TestThreeSubConnsAffinity(t *testing.T) {
subConns[1].UpdateState(balancer.SubConnState{ConnectivityState: connectivity.Ready})
p3 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
gotSCSt, _ := p3.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != subConns[1] {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, subConns[1])
}
@ -346,7 +346,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) {
// SubConn.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
// The picked SubConn should be the second in the ring.
@ -356,7 +356,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) {
t.Fatalf("Timed out waiting for SubConn creation.")
case sc0 = <-cc.NewSubConnCh:
}
if got, want := sc0.Addresses[0].Addr, ring0.items[1].firstAddr; got != want {
if got, want := sc0.Addresses[0].Addr, ring0.items[1].hashKey; got != want {
t.Fatalf("SubConn.Address = %v, want = %v", got, want)
}
select {
@ -375,7 +375,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) {
// First hash should always pick sc0.
p1 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
gotSCSt, _ := p1.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != sc0 {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
}
@ -384,7 +384,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) {
secondHash := ring0.items[1].hash
// secondHash+1 will pick the third SubConn from the ring.
testHash2 := secondHash + 1
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash2)}); err != balancer.ErrNoSubConnAvailable {
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash2)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
var sc1 *testutils.TestSubConn
@ -393,7 +393,7 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) {
t.Fatalf("Timed out waiting for SubConn creation.")
case sc1 = <-cc.NewSubConnCh:
}
if got, want := sc1.Addresses[0].Addr, ring0.items[2].firstAddr; got != want {
if got, want := sc1.Addresses[0].Addr, ring0.items[2].hashKey; got != want {
t.Fatalf("SubConn.Address = %v, want = %v", got, want)
}
select {
@ -407,14 +407,14 @@ func (s) TestThreeBackendsAffinityMultiple(t *testing.T) {
// With the new generated picker, hash2 always picks sc1.
p2 := <-cc.NewPickerCh
for i := 0; i < 5; i++ {
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash2)})
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash2)})
if gotSCSt.SubConn != sc1 {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1)
}
}
// But the first hash still picks sc0.
for i := 0; i < 5; i++ {
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, testHash)})
gotSCSt, _ := p2.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, testHash)})
if gotSCSt.SubConn != sc0 {
t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0)
}
@ -504,14 +504,14 @@ func (s) TestAddrWeightChange(t *testing.T) {
t.Fatalf("new picker after changing address weight has %d entries, want 3", len(p3.(*picker).ring.items))
}
for _, i := range p3.(*picker).ring.items {
if i.firstAddr == testBackendAddrStrs[0] {
if i.hashKey == testBackendAddrStrs[0] {
if i.weight != 1 {
t.Fatalf("new picker after changing address weight has weight %d for %v, want 1", i.weight, i.firstAddr)
t.Fatalf("new picker after changing address weight has weight %d for %v, want 1", i.weight, i.hashKey)
}
}
if i.firstAddr == testBackendAddrStrs[1] {
if i.hashKey == testBackendAddrStrs[1] {
if i.weight != 2 {
t.Fatalf("new picker after changing address weight has weight %d for %v, want 2", i.weight, i.firstAddr)
t.Fatalf("new picker after changing address weight has weight %d for %v, want 2", i.weight, i.hashKey)
}
}
}
@ -532,6 +532,7 @@ func (s) TestAutoConnectEndpointOnTransientFailure(t *testing.T) {
// ringhash won't tell SCs to connect until there is an RPC, so simulate
// one now.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
ctx = SetXDSRequestHash(ctx, 0)
defer cancel()
p0.Pick(balancer.PickInfo{Ctx: ctx})
@ -690,7 +691,7 @@ func (s) TestAddrBalancerAttributesChange(t *testing.T) {
// only Endpoint which has a single address.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetRequestHash(ctx, firstHash)}); err != balancer.ErrNoSubConnAvailable {
if _, err := p0.Pick(balancer.PickInfo{Ctx: SetXDSRequestHash(ctx, firstHash)}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("first pick returned err %v, want %v", err, balancer.ErrNoSubConnAvailable)
}
select {

View File

@ -18,23 +18,25 @@
package ringhash
import "context"
import (
"context"
)
type clusterKey struct{}
type xdsHashKey struct{}
func getRequestHash(ctx context.Context) uint64 {
requestHash, _ := ctx.Value(clusterKey{}).(uint64)
return requestHash
// XDSRequestHash returns the request hash in the context and true if it was set
// from the xDS config selector. If the xDS config selector has not set the hash,
// it returns 0 and false.
func XDSRequestHash(ctx context.Context) (uint64, bool) {
requestHash := ctx.Value(xdsHashKey{})
if requestHash == nil {
return 0, false
}
return requestHash.(uint64), true
}
// GetRequestHashForTesting returns the request hash in the context; to be used
// for testing only.
func GetRequestHashForTesting(ctx context.Context) uint64 {
return getRequestHash(ctx)
}
// SetRequestHash adds the request hash to the context for use in Ring Hash Load
// Balancing.
func SetRequestHash(ctx context.Context, requestHash uint64) context.Context {
return context.WithValue(ctx, clusterKey{}, requestHash)
// SetXDSRequestHash adds the request hash to the context for use in Ring Hash
// Load Balancing using xDS route hash_policy.
func SetXDSRequestHash(ctx context.Context, requestHash uint64) context.Context {
return context.WithValue(ctx, xdsHashKey{}, requestHash)
}

View File

@ -203,7 +203,7 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP
}
lbCtx := clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name)
lbCtx = ringhash.SetRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies))
lbCtx = ringhash.SetXDSRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies))
config := &iresolver.RPCConfig{
// Communicate to the LB policy the chosen cluster and request hash, if Ring Hash LB policy.

View File

@ -495,8 +495,11 @@ func (s) TestResolverRequestHash(t *testing.T) {
if err != nil {
t.Fatalf("cs.SelectConfig(): %v", err)
}
gotHash := ringhash.GetRequestHashForTesting(res.Context)
wantHash := xxhash.Sum64String("/products")
gotHash, ok := ringhash.XDSRequestHash(res.Context)
if !ok {
t.Fatalf("Got no request hash, want: %v", wantHash)
}
if gotHash != wantHash {
t.Fatalf("Got request hash: %v, want: %v", gotHash, wantHash)
}

View File

@ -52,6 +52,7 @@ type Endpoint struct {
Addresses []string
HealthStatus EndpointHealthStatus
Weight uint32
HashKey string
}
// Locality contains information of a locality.

View File

@ -111,11 +111,31 @@ func parseEndpoints(lbEndpoints []*v3endpointpb.LbEndpoint, uniqueEndpointAddrs
HealthStatus: EndpointHealthStatus(lbEndpoint.GetHealthStatus()),
Addresses: addrs,
Weight: weight,
HashKey: hashKey(lbEndpoint),
})
}
return endpoints, nil
}
// hashKey extracts and returns the hash key from the given LbEndpoint. If no
// hash key is found, it returns an empty string.
func hashKey(lbEndpoint *v3endpointpb.LbEndpoint) string {
// "The xDS resolver, described in A74, will be changed to set the hash_key
// endpoint attribute to the value of LbEndpoint.Metadata envoy.lb hash_key
// field, as described in Envoy's documentation for the ring hash load
// balancer." - A76
if envconfig.XDSEndpointHashKeyBackwardCompat {
return ""
}
envoyLB := lbEndpoint.GetMetadata().GetFilterMetadata()["envoy.lb"]
if envoyLB != nil {
if h := envoyLB.GetFields()["hash_key"]; h != nil {
return h.GetStringValue()
}
}
return ""
}
func parseEDSRespProto(m *v3endpointpb.ClusterLoadAssignment) (EndpointsUpdate, error) {
ret := EndpointsUpdate{}
for _, dropPolicy := range m.GetPolicy().GetDropOverloads() {

View File

@ -34,7 +34,9 @@ import (
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/xds/internal"
"google.golang.org/grpc/xds/internal/xdsclient/xdsresource/version"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
@ -333,6 +335,135 @@ func (s) TestEDSParseRespProtoAdditionalAddrs(t *testing.T) {
}
}
func (s) TestUnmarshalEndpointHashKey(t *testing.T) {
baseCLA := &v3endpointpb.ClusterLoadAssignment{
Endpoints: []*v3endpointpb.LocalityLbEndpoints{
{
Locality: &v3corepb.Locality{Region: "r"},
LbEndpoints: []*v3endpointpb.LbEndpoint{
{
HostIdentifier: &v3endpointpb.LbEndpoint_Endpoint{
Endpoint: &v3endpointpb.Endpoint{
Address: &v3corepb.Address{
Address: &v3corepb.Address_SocketAddress{
SocketAddress: &v3corepb.SocketAddress{
Address: "test-address",
PortSpecifier: &v3corepb.SocketAddress_PortValue{
PortValue: 8080,
},
},
},
},
},
},
},
},
LoadBalancingWeight: &wrapperspb.UInt32Value{Value: 1},
},
},
}
tests := []struct {
name string
metadata *v3corepb.Metadata
wantHashKey string
compatEnvVar bool
}{
{
name: "no metadata",
metadata: nil,
wantHashKey: "",
},
{
name: "empty metadata",
metadata: &v3corepb.Metadata{},
wantHashKey: "",
},
{
name: "filter metadata without envoy.lb",
metadata: &v3corepb.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"test-filter": {},
},
},
wantHashKey: "",
},
{
name: "nil envoy.lb",
metadata: &v3corepb.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"envoy.lb": nil,
},
},
wantHashKey: "",
},
{
name: "envoy.lb without hash key",
metadata: &v3corepb.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"envoy.lb": {
Fields: map[string]*structpb.Value{
"hash_key": {
Kind: &structpb.Value_NumberValue{NumberValue: 123.0},
},
},
},
},
},
wantHashKey: "",
},
{
name: "envoy.lb with hash key, compat mode off",
metadata: &v3corepb.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"envoy.lb": {
Fields: map[string]*structpb.Value{
"hash_key": {
Kind: &structpb.Value_StringValue{StringValue: "test-hash-key"},
},
},
},
},
},
wantHashKey: "test-hash-key",
},
{
name: "envoy.lb with hash key, compat mode on",
metadata: &v3corepb.Metadata{
FilterMetadata: map[string]*structpb.Struct{
"envoy.lb": {
Fields: map[string]*structpb.Value{
"hash_key": {
Kind: &structpb.Value_StringValue{StringValue: "test-hash-key"},
},
},
},
},
},
wantHashKey: "",
compatEnvVar: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
testutils.SetEnvConfig(t, &envconfig.XDSEndpointHashKeyBackwardCompat, test.compatEnvVar)
cla := proto.Clone(baseCLA).(*v3endpointpb.ClusterLoadAssignment)
cla.Endpoints[0].LbEndpoints[0].Metadata = test.metadata
marshalledCLA := testutils.MarshalAny(t, cla)
_, update, err := unmarshalEndpointsResource(marshalledCLA)
if err != nil {
t.Fatalf("unmarshalEndpointsResource() got error = %v, want success", err)
}
got := update.Localities[0].Endpoints[0].HashKey
if got != test.wantHashKey {
t.Errorf("unmarshalEndpointResource() endpoint hash key: got %s, want %s", got, test.wantHashKey)
}
})
}
}
func (s) TestUnmarshalEndpoints(t *testing.T) {
var v3EndpointsAny = testutils.MarshalAny(t, func() *v3endpointpb.ClusterLoadAssignment {
clab0 := newClaBuilder("test", nil)