xds: handle weighted cluster as route action (#3613)

This commit is contained in:
Menghan Li 2020-05-26 13:58:04 -07:00 committed by GitHub
parent 4709b05f2c
commit d071d56834
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 499 additions and 122 deletions

View File

@ -25,6 +25,7 @@ import (
"time" "time"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/serviceconfig"
) )
var ( var (
@ -46,6 +47,11 @@ var (
// ParseServiceConfigForTesting is for creating a fake // ParseServiceConfigForTesting is for creating a fake
// ClientConn for resolver testing only // ClientConn for resolver testing only
ParseServiceConfigForTesting interface{} // func(string) *serviceconfig.ParseResult ParseServiceConfigForTesting interface{} // func(string) *serviceconfig.ParseResult
// EqualServiceConfigForTesting is for testing service config generation and
// parsing. Both a and b should be returned by ParseServiceConfigForTesting.
// This function compares the config without rawJSON stripped, in case the
// there's difference in white space.
EqualServiceConfigForTesting func(a, b serviceconfig.Config) bool
) )
// HealthChecker defines the signature of the client-side LB channel health checking function. // HealthChecker defines the signature of the client-side LB channel health checking function.

View File

@ -21,6 +21,7 @@ package grpc
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -400,3 +401,34 @@ func getMaxSize(mcMax, doptMax *int, defaultVal int) *int {
func newInt(b int) *int { func newInt(b int) *int {
return &b return &b
} }
func init() {
internal.EqualServiceConfigForTesting = equalServiceConfig
}
// equalServiceConfig compares two configs. The rawJSONString field is ignored,
// because they may diff in white spaces.
//
// If any of them is NOT *ServiceConfig, return false.
func equalServiceConfig(a, b serviceconfig.Config) bool {
aa, ok := a.(*ServiceConfig)
if !ok {
return false
}
bb, ok := b.(*ServiceConfig)
if !ok {
return false
}
aaRaw := aa.rawJSONString
aa.rawJSONString = ""
bbRaw := bb.rawJSONString
bb.rawJSONString = ""
defer func() {
aa.rawJSONString = aaRaw
bb.rawJSONString = bbRaw
}()
// Using reflect.DeepEqual instead of cmp.Equal because many balancer
// configs are unexported, and cmp.Equal cannot compare unexported fields
// from unexported structs.
return reflect.DeepEqual(aa, bb)
}

View File

@ -22,4 +22,5 @@ package balancer
import ( import (
_ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // Register the CDS balancer _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // Register the CDS balancer
_ "google.golang.org/grpc/xds/internal/balancer/edsbalancer" // Register the EDS balancer _ "google.golang.org/grpc/xds/internal/balancer/edsbalancer" // Register the EDS balancer
_ "google.golang.org/grpc/xds/internal/balancer/weightedtarget" // Register the weighted_target balancer
) )

View File

@ -23,7 +23,7 @@ import (
) )
type rdsUpdate struct { type rdsUpdate struct {
clusterName string weightedCluster map[string]uint32
} }
type rdsCallbackFunc func(rdsUpdate, error) type rdsCallbackFunc func(rdsUpdate, error)

View File

@ -21,6 +21,7 @@ package client
import ( import (
"testing" "testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils"
) )
@ -50,12 +51,12 @@ func (s) TestRDSWatch(t *testing.T) {
rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err})
}) })
wantUpdate := rdsUpdate{clusterName: testCDSName} wantUpdate := rdsUpdate{weightedCluster: map[string]uint32{testCDSName: 1}}
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName: wantUpdate, testRDSName: wantUpdate,
}) })
if u, err := rdsUpdateCh.Receive(); err != nil || u != (rdsUpdateErr{wantUpdate, nil}) { if u, err := rdsUpdateCh.Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdate{}, rdsUpdateErr{})) {
t.Errorf("unexpected rdsUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected rdsUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -106,13 +107,13 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) {
}) })
} }
wantUpdate := rdsUpdate{clusterName: testCDSName} wantUpdate := rdsUpdate{weightedCluster: map[string]uint32{testCDSName: 1}}
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName: wantUpdate, testRDSName: wantUpdate,
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := rdsUpdateChs[i].Receive(); err != nil || u != (rdsUpdateErr{wantUpdate, nil}) { if u, err := rdsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdate{}, rdsUpdateErr{})) {
t.Errorf("i=%v, unexpected rdsUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected rdsUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
@ -124,7 +125,7 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count-1; i++ { for i := 0; i < count-1; i++ {
if u, err := rdsUpdateChs[i].Receive(); err != nil || u != (rdsUpdateErr{wantUpdate, nil}) { if u, err := rdsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdate{}, rdsUpdateErr{})) {
t.Errorf("i=%v, unexpected rdsUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected rdsUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
@ -166,20 +167,20 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) {
rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err})
}) })
wantUpdate1 := rdsUpdate{clusterName: testCDSName + "1"} wantUpdate1 := rdsUpdate{weightedCluster: map[string]uint32{testCDSName + "1": 1}}
wantUpdate2 := rdsUpdate{clusterName: testCDSName + "2"} wantUpdate2 := rdsUpdate{weightedCluster: map[string]uint32{testCDSName + "2": 1}}
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName + "1": wantUpdate1, testRDSName + "1": wantUpdate1,
testRDSName + "2": wantUpdate2, testRDSName + "2": wantUpdate2,
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := rdsUpdateChs[i].Receive(); err != nil || u != (rdsUpdateErr{wantUpdate1, nil}) { if u, err := rdsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate1, nil}, cmp.AllowUnexported(rdsUpdate{}, rdsUpdateErr{})) {
t.Errorf("i=%v, unexpected rdsUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected rdsUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := rdsUpdateCh2.Receive(); err != nil || u != (rdsUpdateErr{wantUpdate2, nil}) { if u, err := rdsUpdateCh2.Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate2, nil}, cmp.AllowUnexported(rdsUpdate{}, rdsUpdateErr{})) {
t.Errorf("unexpected rdsUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected rdsUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -203,12 +204,12 @@ func (s) TestRDSWatchAfterCache(t *testing.T) {
rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err})
}) })
wantUpdate := rdsUpdate{clusterName: testCDSName} wantUpdate := rdsUpdate{weightedCluster: map[string]uint32{testCDSName: 1}}
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName: wantUpdate, testRDSName: wantUpdate,
}) })
if u, err := rdsUpdateCh.Receive(); err != nil || u != (rdsUpdateErr{wantUpdate, nil}) { if u, err := rdsUpdateCh.Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdate{}, rdsUpdateErr{})) {
t.Errorf("unexpected rdsUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected rdsUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -219,7 +220,7 @@ func (s) TestRDSWatchAfterCache(t *testing.T) {
}) })
// New watch should receives the update. // New watch should receives the update.
if u, err := rdsUpdateCh2.Receive(); err != nil || u != (rdsUpdateErr{wantUpdate, nil}) { if u, err := rdsUpdateCh2.Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdate{}, rdsUpdateErr{})) {
t.Errorf("unexpected rdsUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected rdsUpdate: %v, error receiving from channel: %v", u, err)
} }

View File

@ -25,7 +25,9 @@ import (
// ServiceUpdate contains update about the service. // ServiceUpdate contains update about the service.
type ServiceUpdate struct { type ServiceUpdate struct {
Cluster string // WeightedCluster is a map from cluster names (CDS resource to watch) to
// their weights.
WeightedCluster map[string]uint32
} }
// WatchService uses LDS and RDS to discover information about the provided // WatchService uses LDS and RDS to discover information about the provided
@ -106,7 +108,9 @@ func (w *serviceUpdateWatcher) handleRDSResp(update rdsUpdate, err error) {
w.serviceCb(ServiceUpdate{}, err) w.serviceCb(ServiceUpdate{}, err)
return return
} }
w.serviceCb(ServiceUpdate{Cluster: update.clusterName}, nil) w.serviceCb(ServiceUpdate{
WeightedCluster: update.weightedCluster,
}, nil)
} }
func (w *serviceUpdateWatcher) close() { func (w *serviceUpdateWatcher) close() {

View File

@ -24,6 +24,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils"
"google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/testutils/fakeserver"
) )
@ -54,7 +55,7 @@ func (s) TestServiceWatch(t *testing.T) {
serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err}) serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err})
}) })
wantUpdate := ServiceUpdate{Cluster: testCDSName} wantUpdate := ServiceUpdate{WeightedCluster: map[string]uint32{testCDSName: 1}}
<-v2Client.addWatches[ldsURL] <-v2Client.addWatches[ldsURL]
v2Client.r.newLDSUpdate(map[string]ldsUpdate{ v2Client.r.newLDSUpdate(map[string]ldsUpdate{
@ -62,10 +63,10 @@ func (s) TestServiceWatch(t *testing.T) {
}) })
<-v2Client.addWatches[rdsURL] <-v2Client.addWatches[rdsURL]
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName: {clusterName: testCDSName}, testRDSName: {weightedCluster: map[string]uint32{testCDSName: 1}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || u != (serviceUpdateErr{wantUpdate, nil}) { if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(serviceUpdateErr{})) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -90,7 +91,7 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) {
serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err}) serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err})
}) })
wantUpdate := ServiceUpdate{Cluster: testCDSName} wantUpdate := ServiceUpdate{WeightedCluster: map[string]uint32{testCDSName: 1}}
<-v2Client.addWatches[ldsURL] <-v2Client.addWatches[ldsURL]
v2Client.r.newLDSUpdate(map[string]ldsUpdate{ v2Client.r.newLDSUpdate(map[string]ldsUpdate{
@ -98,10 +99,10 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) {
}) })
<-v2Client.addWatches[rdsURL] <-v2Client.addWatches[rdsURL]
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName: {clusterName: testCDSName}, testRDSName: {weightedCluster: map[string]uint32{testCDSName: 1}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || u != (serviceUpdateErr{wantUpdate, nil}) { if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(serviceUpdateErr{})) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -113,20 +114,20 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) {
// Another update for the old name. // Another update for the old name.
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName: {clusterName: testCDSName}, testRDSName: {weightedCluster: map[string]uint32{testCDSName: 1}},
}) })
if u, err := serviceUpdateCh.Receive(); err != testutils.ErrRecvTimeout { if u, err := serviceUpdateCh.Receive(); err != testutils.ErrRecvTimeout {
t.Errorf("unexpected serviceUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected serviceUpdate: %v, %v, want channel recv timeout", u, err)
} }
wantUpdate2 := ServiceUpdate{Cluster: testCDSName + "2"} wantUpdate2 := ServiceUpdate{WeightedCluster: map[string]uint32{testCDSName + "2": 1}}
// RDS update for the new name. // RDS update for the new name.
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName + "2": {clusterName: testCDSName + "2"}, testRDSName + "2": {weightedCluster: map[string]uint32{testCDSName + "2": 1}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || u != (serviceUpdateErr{wantUpdate2, nil}) { if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate2, nil}, cmp.AllowUnexported(serviceUpdateErr{})) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -151,7 +152,7 @@ func (s) TestServiceWatchSecond(t *testing.T) {
serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err}) serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err})
}) })
wantUpdate := ServiceUpdate{Cluster: testCDSName} wantUpdate := ServiceUpdate{WeightedCluster: map[string]uint32{testCDSName: 1}}
<-v2Client.addWatches[ldsURL] <-v2Client.addWatches[ldsURL]
v2Client.r.newLDSUpdate(map[string]ldsUpdate{ v2Client.r.newLDSUpdate(map[string]ldsUpdate{
@ -159,10 +160,10 @@ func (s) TestServiceWatchSecond(t *testing.T) {
}) })
<-v2Client.addWatches[rdsURL] <-v2Client.addWatches[rdsURL]
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName: {clusterName: testCDSName}, testRDSName: {weightedCluster: map[string]uint32{testCDSName: 1}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || u != (serviceUpdateErr{wantUpdate, nil}) { if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(serviceUpdateErr{})) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -177,7 +178,7 @@ func (s) TestServiceWatchSecond(t *testing.T) {
t.Fatalf("failed to get serviceUpdate: %v", err) t.Fatalf("failed to get serviceUpdate: %v", err)
} }
uu := u.(serviceUpdateErr) uu := u.(serviceUpdateErr)
if uu.u != (ServiceUpdate{}) { if !cmp.Equal(uu.u, ServiceUpdate{}) {
t.Errorf("unexpected serviceUpdate: %v, want %v", uu.u, ServiceUpdate{}) t.Errorf("unexpected serviceUpdate: %v, want %v", uu.u, ServiceUpdate{})
} }
if uu.err == nil { if uu.err == nil {
@ -190,10 +191,10 @@ func (s) TestServiceWatchSecond(t *testing.T) {
testLDSName: {routeName: testRDSName}, testLDSName: {routeName: testRDSName},
}) })
v2Client.r.newRDSUpdate(map[string]rdsUpdate{ v2Client.r.newRDSUpdate(map[string]rdsUpdate{
testRDSName: {clusterName: testCDSName}, testRDSName: {weightedCluster: map[string]uint32{testCDSName: 1}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || u != (serviceUpdateErr{wantUpdate, nil}) { if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(serviceUpdateErr{})) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -227,8 +228,8 @@ func (s) TestServiceWatchWithNoResponseFromServer(t *testing.T) {
callbackCh := testutils.NewChannel() callbackCh := testutils.NewChannel()
cancelWatch := xdsClient.WatchService(goodLDSTarget1, func(su ServiceUpdate, err error) { cancelWatch := xdsClient.WatchService(goodLDSTarget1, func(su ServiceUpdate, err error) {
if su.Cluster != "" { if su.WeightedCluster != nil {
callbackCh.Send(fmt.Errorf("got clusterName: %+v, want empty clusterName", su.Cluster)) callbackCh.Send(fmt.Errorf("got WeightedCluster: %+v, want nil", su.WeightedCluster))
return return
} }
if err == nil { if err == nil {
@ -271,8 +272,8 @@ func (s) TestServiceWatchEmptyRDS(t *testing.T) {
callbackCh := testutils.NewChannel() callbackCh := testutils.NewChannel()
cancelWatch := xdsClient.WatchService(goodLDSTarget1, func(su ServiceUpdate, err error) { cancelWatch := xdsClient.WatchService(goodLDSTarget1, func(su ServiceUpdate, err error) {
if su.Cluster != "" { if su.WeightedCluster != nil {
callbackCh.Send(fmt.Errorf("got clusterName: %+v, want empty clusterName", su.Cluster)) callbackCh.Send(fmt.Errorf("got WeightedCluster: %+v, want nil", su.WeightedCluster))
return return
} }
if err == nil { if err == nil {

View File

@ -48,26 +48,25 @@ func (v2c *v2Client) handleRDSResponse(resp *xdspb.DiscoveryResponse) error {
v2c.logger.Infof("Resource with name: %v, type: %T, contains: %v. Picking routes for current watching hostname %v", rc.GetName(), rc, rc, v2c.hostname) v2c.logger.Infof("Resource with name: %v, type: %T, contains: %v. Picking routes for current watching hostname %v", rc.GetName(), rc, rc, v2c.hostname)
// Use the hostname (resourceName for LDS) to find the routes. // Use the hostname (resourceName for LDS) to find the routes.
cluster, err := getClusterFromRouteConfiguration(rc, hostname) u, err := generateRDSUpdateFromRouteConfiguration(rc, hostname)
if cluster == "" { if err != nil {
return fmt.Errorf("xds: received invalid RouteConfiguration in RDS response: %+v with err: %v", rc, err) return fmt.Errorf("xds: received invalid RouteConfiguration in RDS response: %+v with err: %v", rc, err)
} }
// If we get here, it means that this resource was a good one. // If we get here, it means that this resource was a good one.
returnUpdate[rc.GetName()] = rdsUpdate{clusterName: cluster} returnUpdate[rc.GetName()] = u
} }
v2c.parent.newRDSUpdate(returnUpdate) v2c.parent.newRDSUpdate(returnUpdate)
return nil return nil
} }
// getClusterFromRouteConfiguration checks if the provided RouteConfiguration // generateRDSUpdateFromRouteConfiguration checks if the provided
// meets the expected criteria. If so, it returns a non-empty clusterName with // RouteConfiguration meets the expected criteria. If so, it returns a rdsUpdate
// nil error. // with nil error.
// //
// A RouteConfiguration resource is considered valid when only if it contains a // A RouteConfiguration resource is considered valid when only if it contains a
// VirtualHost whose domain field matches the server name from the URI passed // VirtualHost whose domain field matches the server name from the URI passed
// to the gRPC channel, and it contains a clusterName. // to the gRPC channel, and it contains a clusterName or a weighted cluster.
// //
// The RouteConfiguration includes a list of VirtualHosts, which may have zero // The RouteConfiguration includes a list of VirtualHosts, which may have zero
// or more elements. We are interested in the element whose domains field // or more elements. We are interested in the element whose domains field
@ -75,8 +74,9 @@ func (v2c *v2Client) handleRDSResponse(resp *xdspb.DiscoveryResponse) error {
// VirtualHost proto that the we are interested in is the list of routes. We // VirtualHost proto that the we are interested in is the list of routes. We
// only look at the last route in the list (the default route), whose match // only look at the last route in the list (the default route), whose match
// field must be empty and whose route field must be set. Inside that route // field must be empty and whose route field must be set. Inside that route
// message, the cluster field will contain the clusterName we are looking for. // message, the cluster field will contain the clusterName or weighted clusters
func getClusterFromRouteConfiguration(rc *xdspb.RouteConfiguration, host string) (string, error) { // we are looking for.
func generateRDSUpdateFromRouteConfiguration(rc *xdspb.RouteConfiguration, host string) (rdsUpdate, error) {
// //
// Currently this returns "" on error, and the caller will return an error. // Currently this returns "" on error, and the caller will return an error.
// But the error doesn't contain details of why the response is invalid // But the error doesn't contain details of why the response is invalid
@ -87,31 +87,66 @@ func getClusterFromRouteConfiguration(rc *xdspb.RouteConfiguration, host string)
vh := findBestMatchingVirtualHost(host, rc.GetVirtualHosts()) vh := findBestMatchingVirtualHost(host, rc.GetVirtualHosts())
if vh == nil { if vh == nil {
// No matching virtual host found. // No matching virtual host found.
return "", fmt.Errorf("no matching virtual host found") return rdsUpdate{}, fmt.Errorf("no matching virtual host found")
} }
if len(vh.Routes) == 0 { if len(vh.Routes) == 0 {
// The matched virtual host has no routes, this is invalid because there // The matched virtual host has no routes, this is invalid because there
// should be at least one default route. // should be at least one default route.
return "", fmt.Errorf("matched virtual host has no routes") return rdsUpdate{}, fmt.Errorf("matched virtual host has no routes")
} }
dr := vh.Routes[len(vh.Routes)-1] dr := vh.Routes[len(vh.Routes)-1]
match := dr.GetMatch() match := dr.GetMatch()
if match == nil { if match == nil {
return "", fmt.Errorf("matched virtual host's default route doesn't have a match") return rdsUpdate{}, fmt.Errorf("matched virtual host's default route doesn't have a match")
} }
if prefix := match.GetPrefix(); prefix != "" && prefix != "/" { if prefix := match.GetPrefix(); prefix != "" && prefix != "/" {
// The matched virtual host is invalid. Match is not "" or "/". // The matched virtual host is invalid. Match is not "" or "/".
return "", fmt.Errorf("matched virtual host's default route is %v, want Prefix empty string or /", match) return rdsUpdate{}, fmt.Errorf("matched virtual host's default route is %v, want Prefix empty string or /", match)
} }
if caseSensitive := match.GetCaseSensitive(); caseSensitive != nil && !caseSensitive.Value { if caseSensitive := match.GetCaseSensitive(); caseSensitive != nil && !caseSensitive.Value {
// The case sensitive is set to false. Not set or set to true are both // The case sensitive is set to false. Not set or set to true are both
// valid. // valid.
return "", fmt.Errorf("matches virtual host's default route set case-sensitive to false") return rdsUpdate{}, fmt.Errorf("matched virtual host's default route set case-sensitive to false")
} }
if route := dr.GetRoute(); route != nil { route := dr.GetRoute()
return route.GetCluster(), nil if route == nil {
return rdsUpdate{}, fmt.Errorf("matched route is nil")
} }
return "", fmt.Errorf("matched route is nil")
if wc := route.GetWeightedClusters(); wc != nil {
m, err := weightedClustersProtoToMap(wc)
if err != nil {
return rdsUpdate{}, fmt.Errorf("matched weighted cluster is invalid: %v", err)
}
return rdsUpdate{weightedCluster: m}, nil
}
// When there's just one cluster, we set weightedCluster to map with one
// entry. This mean we will build a weighted_target balancer even if there's
// just one cluster.
//
// Otherwise, we will need to switch the top policy between weighted_target
// and CDS. In case when the action changes between one cluster and multiple
// clusters, changing top level policy means recreating TCP connection every
// time.
return rdsUpdate{weightedCluster: map[string]uint32{route.GetCluster(): 1}}, nil
}
func weightedClustersProtoToMap(wc *routepb.WeightedCluster) (map[string]uint32, error) {
ret := make(map[string]uint32)
var totalWeight uint32 = 100
if t := wc.GetTotalWeight().GetValue(); t != 0 {
totalWeight = t
}
for _, cw := range wc.Clusters {
w := cw.Weight.GetValue()
ret[cw.Name] = w
totalWeight -= w
}
if totalWeight != 0 {
return nil, fmt.Errorf("weights of clusters do not add up to total total weight, difference: %v", totalWeight)
}
return ret, nil
} }
type domainMatchType int type domainMatchType int

View File

@ -30,23 +30,21 @@ import (
"google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/testutils/fakeserver"
) )
func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) { func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
rc *xdspb.RouteConfiguration rc *xdspb.RouteConfiguration
wantCluster string wantUpdate rdsUpdate
wantError bool wantError bool
}{ }{
{ {
name: "no-virtual-hosts-in-rc", name: "no-virtual-hosts-in-rc",
rc: emptyRouteConfig, rc: emptyRouteConfig,
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
name: "no-domains-in-rc", name: "no-domains-in-rc",
rc: noDomainsInRouteConfig, rc: noDomainsInRouteConfig,
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
@ -56,7 +54,6 @@ func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) {
{Domains: []string{uninterestingDomain}}, {Domains: []string{uninterestingDomain}},
}, },
}, },
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
@ -66,7 +63,6 @@ func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) {
{Domains: []string{goodLDSTarget1}}, {Domains: []string{goodLDSTarget1}},
}, },
}, },
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
@ -87,7 +83,6 @@ func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) {
}, },
}, },
}, },
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
@ -105,7 +100,6 @@ func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) {
}, },
}, },
}, },
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
@ -118,7 +112,6 @@ func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) {
}, },
}, },
}, },
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
@ -139,7 +132,6 @@ func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) {
}, },
}, },
}, },
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
@ -158,14 +150,12 @@ func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) {
Route: &routepb.RouteAction{ Route: &routepb.RouteAction{
ClusterSpecifier: &routepb.RouteAction_Cluster{Cluster: goodClusterName1}, ClusterSpecifier: &routepb.RouteAction_Cluster{Cluster: goodClusterName1},
}}}}}}}, }}}}}}},
wantCluster: "",
wantError: true, wantError: true,
}, },
{ {
name: "good-route-config-with-empty-string-route", name: "good-route-config-with-empty-string-route",
rc: goodRouteConfig1, rc: goodRouteConfig1,
wantCluster: goodClusterName1, wantUpdate: rdsUpdate{weightedCluster: map[string]uint32{goodClusterName1: 1}},
wantError: false,
}, },
{ {
// default route's match is not empty string, but "/". // default route's match is not empty string, but "/".
@ -180,15 +170,59 @@ func (s) TestRDSGetClusterFromRouteConfiguration(t *testing.T) {
Route: &routepb.RouteAction{ Route: &routepb.RouteAction{
ClusterSpecifier: &routepb.RouteAction_Cluster{Cluster: goodClusterName1}, ClusterSpecifier: &routepb.RouteAction_Cluster{Cluster: goodClusterName1},
}}}}}}}, }}}}}}},
wantCluster: goodClusterName1, wantUpdate: rdsUpdate{weightedCluster: map[string]uint32{goodClusterName1: 1}},
},
{
// weights not add up to total-weight.
name: "route-config-with-weighted_clusters_weights_not_add_up",
rc: &xdspb.RouteConfiguration{
Name: goodRouteName1,
VirtualHosts: []*routepb.VirtualHost{{
Domains: []string{goodLDSTarget1},
Routes: []*routepb.Route{{
Match: &routepb.RouteMatch{PathSpecifier: &routepb.RouteMatch_Prefix{Prefix: "/"}},
Action: &routepb.Route_Route{
Route: &routepb.RouteAction{
ClusterSpecifier: &routepb.RouteAction_WeightedClusters{
WeightedClusters: &routepb.WeightedCluster{
Clusters: []*routepb.WeightedCluster_ClusterWeight{
{Name: "a", Weight: &wrapperspb.UInt32Value{Value: 2}},
{Name: "b", Weight: &wrapperspb.UInt32Value{Value: 3}},
{Name: "c", Weight: &wrapperspb.UInt32Value{Value: 5}},
},
TotalWeight: &wrapperspb.UInt32Value{Value: 30},
}}}}}}}}},
wantError: true,
},
{
name: "good-route-config-with-weighted_clusters",
rc: &xdspb.RouteConfiguration{
Name: goodRouteName1,
VirtualHosts: []*routepb.VirtualHost{{
Domains: []string{goodLDSTarget1},
Routes: []*routepb.Route{{
Match: &routepb.RouteMatch{PathSpecifier: &routepb.RouteMatch_Prefix{Prefix: "/"}},
Action: &routepb.Route_Route{
Route: &routepb.RouteAction{
ClusterSpecifier: &routepb.RouteAction_WeightedClusters{
WeightedClusters: &routepb.WeightedCluster{
Clusters: []*routepb.WeightedCluster_ClusterWeight{
{Name: "a", Weight: &wrapperspb.UInt32Value{Value: 2}},
{Name: "b", Weight: &wrapperspb.UInt32Value{Value: 3}},
{Name: "c", Weight: &wrapperspb.UInt32Value{Value: 5}},
},
TotalWeight: &wrapperspb.UInt32Value{Value: 10},
}}}}}}}}},
wantUpdate: rdsUpdate{weightedCluster: map[string]uint32{"a": 2, "b": 3, "c": 5}},
}, },
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
gotCluster, gotError := getClusterFromRouteConfiguration(test.rc, goodLDSTarget1) gotUpdate, gotError := generateRDSUpdateFromRouteConfiguration(test.rc, goodLDSTarget1)
if gotCluster != test.wantCluster || (gotError != nil) != test.wantError { if !cmp.Equal(gotUpdate, test.wantUpdate, cmp.AllowUnexported(rdsUpdate{})) || (gotError != nil) != test.wantError {
t.Errorf("getClusterFromRouteConfiguration(%+v, %v) = %v, want %v", test.rc, goodLDSTarget1, gotCluster, test.wantCluster) t.Errorf("generateRDSUpdateFromRouteConfiguration(%+v, %v) = %v, want %v", test.rc, goodLDSTarget1, gotUpdate, test.wantUpdate)
} }
}) })
} }
@ -256,7 +290,7 @@ func (s) TestRDSHandleResponse(t *testing.T) {
name: "one-good-route-config", name: "one-good-route-config",
rdsResponse: goodRDSResponse1, rdsResponse: goodRDSResponse1,
wantErr: false, wantErr: false,
wantUpdate: &rdsUpdate{clusterName: goodClusterName1}, wantUpdate: &rdsUpdate{weightedCluster: map[string]uint32{goodClusterName1: 1}},
wantUpdateErr: false, wantUpdateErr: false,
}, },
} }
@ -412,3 +446,73 @@ func (s) TestFindBestMatchingVirtualHost(t *testing.T) {
}) })
} }
} }
func (s) TestWeightedClustersProtoToMap(t *testing.T) {
tests := []struct {
name string
wc *routepb.WeightedCluster
want map[string]uint32
wantErr bool
}{
{
name: "weight not add up to non default total",
wc: &routepb.WeightedCluster{
Clusters: []*routepb.WeightedCluster_ClusterWeight{
{Name: "a", Weight: &wrapperspb.UInt32Value{Value: 1}},
{Name: "b", Weight: &wrapperspb.UInt32Value{Value: 1}},
{Name: "c", Weight: &wrapperspb.UInt32Value{Value: 1}},
},
TotalWeight: &wrapperspb.UInt32Value{Value: 10},
},
wantErr: true,
},
{
name: "weight not add up to default total",
wc: &routepb.WeightedCluster{
Clusters: []*routepb.WeightedCluster_ClusterWeight{
{Name: "a", Weight: &wrapperspb.UInt32Value{Value: 2}},
{Name: "b", Weight: &wrapperspb.UInt32Value{Value: 3}},
{Name: "c", Weight: &wrapperspb.UInt32Value{Value: 5}},
},
TotalWeight: nil,
},
wantErr: true,
},
{
name: "ok non default total weight",
wc: &routepb.WeightedCluster{
Clusters: []*routepb.WeightedCluster_ClusterWeight{
{Name: "a", Weight: &wrapperspb.UInt32Value{Value: 2}},
{Name: "b", Weight: &wrapperspb.UInt32Value{Value: 3}},
{Name: "c", Weight: &wrapperspb.UInt32Value{Value: 5}},
},
TotalWeight: &wrapperspb.UInt32Value{Value: 10},
},
want: map[string]uint32{"a": 2, "b": 3, "c": 5},
},
{
name: "ok default total weight is 100",
wc: &routepb.WeightedCluster{
Clusters: []*routepb.WeightedCluster_ClusterWeight{
{Name: "a", Weight: &wrapperspb.UInt32Value{Value: 20}},
{Name: "b", Weight: &wrapperspb.UInt32Value{Value: 30}},
{Name: "c", Weight: &wrapperspb.UInt32Value{Value: 50}},
},
TotalWeight: nil,
},
want: map[string]uint32{"a": 20, "b": 30, "c": 50},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := weightedClustersProtoToMap(tt.wc)
if (err != nil) != tt.wantErr {
t.Errorf("weightedClustersProtoToMap() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !cmp.Equal(got, tt.want) {
t.Errorf("weightedClustersProtoToMap() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,80 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package resolver
import (
"encoding/json"
"fmt"
xdsclient "google.golang.org/grpc/xds/internal/client"
)
const (
cdsName = "cds_experimental"
weightedTargetName = "weighted_target_experimental"
)
type serviceConfig struct {
LoadBalancingConfig balancerConfig `json:"loadBalancingConfig"`
}
type balancerConfig []map[string]interface{}
func newBalancerConfig(name string, config interface{}) balancerConfig {
return []map[string]interface{}{{name: config}}
}
type weightedCDSBalancerConfig struct {
Targets map[string]cdsWithWeight `json:"targets"`
}
type cdsWithWeight struct {
Weight uint32 `json:"weight"`
ChildPolicy balancerConfig `json:"childPolicy"`
}
type cdsBalancerConfig struct {
Cluster string `json:"cluster"`
}
func serviceUpdateToJSON(su xdsclient.ServiceUpdate) (string, error) {
// Even if WeightedCluster has only one entry, we still use weighted_target
// as top level balancer, to avoid switching top policy between CDS and
// weighted_target, causing TCP connection to be recreated.
targets := make(map[string]cdsWithWeight)
for name, weight := range su.WeightedCluster {
targets[name] = cdsWithWeight{
Weight: weight,
ChildPolicy: newBalancerConfig(cdsName, cdsBalancerConfig{Cluster: name}),
}
}
sc := serviceConfig{
LoadBalancingConfig: newBalancerConfig(
weightedTargetName, weightedCDSBalancerConfig{
Targets: targets,
},
),
}
bs, err := json.Marshal(sc)
if err != nil {
return "", fmt.Errorf("failed to marshal json: %v", err)
}
return string(bs), nil
}

View File

@ -0,0 +1,92 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package resolver
import (
"testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/serviceconfig"
_ "google.golang.org/grpc/xds/internal/balancer/weightedtarget"
"google.golang.org/grpc/xds/internal/client"
)
const (
testCluster1 = "test-cluster-1"
testClusterOnlyJSON = `{"loadBalancingConfig":[{
"weighted_target_experimental": {
"targets": { "test-cluster-1" : { "weight":1, "childPolicy":[{"cds_experimental":{"cluster":"test-cluster-1"}}] } }
}
}]}`
testWeightedCDSJSON = `{"loadBalancingConfig":[{
"weighted_target_experimental": {
"targets": {
"cluster_1" : {
"weight":75,
"childPolicy":[{"cds_experimental":{"cluster":"cluster_1"}}]
},
"cluster_2" : {
"weight":25,
"childPolicy":[{"cds_experimental":{"cluster":"cluster_2"}}]
}
}
}
}]}`
)
func TestServiceUpdateToJSON(t *testing.T) {
tests := []struct {
name string
su client.ServiceUpdate
wantJSON string // wantJSON is not to be compared verbatim.
}{
{
name: "one cluster only",
su: client.ServiceUpdate{WeightedCluster: map[string]uint32{testCluster1: 1}},
wantJSON: testClusterOnlyJSON,
},
{
name: "weighted clusters",
su: client.ServiceUpdate{WeightedCluster: map[string]uint32{
"cluster_1": 75,
"cluster_2": 25,
}},
wantJSON: testWeightedCDSJSON,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotJSON, err := serviceUpdateToJSON(tt.su)
if err != nil {
t.Errorf("serviceUpdateToJSON returned error: %v", err)
return
}
gotParsed := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(gotJSON)
wantParsed := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(tt.wantJSON)
if !internal.EqualServiceConfigForTesting(gotParsed.Config, wantParsed.Config) {
t.Errorf("serviceUpdateToJSON() = %v, want %v", gotJSON, tt.wantJSON)
t.Error("gotParsed: ", cmp.Diff(nil, gotParsed))
t.Error("wantParsed: ", cmp.Diff(nil, wantParsed))
}
})
}
}

View File

@ -161,16 +161,6 @@ type xdsResolver struct {
cancelWatch func() cancelWatch func()
} }
const jsonFormatSC = `{
"loadBalancingConfig":[
{
"cds_experimental":{
"Cluster": "%s"
}
}
]
}`
// run is a long running goroutine which blocks on receiving service updates // run is a long running goroutine which blocks on receiving service updates
// and passes it on the ClientConn. // and passes it on the ClientConn.
func (r *xdsResolver) run() { func (r *xdsResolver) run() {
@ -183,7 +173,12 @@ func (r *xdsResolver) run() {
r.cc.ReportError(update.err) r.cc.ReportError(update.err)
continue continue
} }
sc := fmt.Sprintf(jsonFormatSC, update.su.Cluster) sc, err := serviceUpdateToJSON(update.su)
if err != nil {
r.logger.Warningf("failed to convert update to service config: %v", err)
r.cc.ReportError(err)
continue
}
r.logger.Infof("Received update on resource %v from xds-client %p, generated service config: %v", r.target.Endpoint, r.client, sc) r.logger.Infof("Received update on resource %v from xds-client %p, generated service config: %v", r.target.Endpoint, r.client, sc)
r.cc.UpdateState(resolver.State{ r.cc.UpdateState(resolver.State{
ServiceConfig: r.cc.ParseServiceConfig(sc), ServiceConfig: r.cc.ParseServiceConfig(sc),

View File

@ -25,12 +25,14 @@ import (
"net" "net"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
xdsinternal "google.golang.org/grpc/xds/internal" xdsinternal "google.golang.org/grpc/xds/internal"
_ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config
"google.golang.org/grpc/xds/internal/client"
xdsclient "google.golang.org/grpc/xds/internal/client" xdsclient "google.golang.org/grpc/xds/internal/client"
"google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/client/bootstrap"
"google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils"
@ -273,7 +275,7 @@ func TestXDSResolverWatchCallbackAfterClose(t *testing.T) {
// Call the watchAPI callback after closing the resolver, and make sure no // Call the watchAPI callback after closing the resolver, and make sure no
// update is triggerred on the ClientConn. // update is triggerred on the ClientConn.
xdsR.Close() xdsR.Close()
xdsC.InvokeWatchServiceCallback(cluster, nil) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{WeightedCluster: map[string]uint32{cluster: 1}}, nil)
if gotVal, gotErr := tcc.stateCh.Receive(); gotErr != testutils.ErrRecvTimeout { if gotVal, gotErr := tcc.stateCh.Receive(); gotErr != testutils.ErrRecvTimeout {
t.Fatalf("ClientConn.UpdateState called after xdsResolver is closed: %v", gotVal) t.Fatalf("ClientConn.UpdateState called after xdsResolver is closed: %v", gotVal)
} }
@ -297,7 +299,7 @@ func TestXDSResolverBadServiceUpdate(t *testing.T) {
// Invoke the watchAPI callback with a bad service update and wait for the // Invoke the watchAPI callback with a bad service update and wait for the
// ReportError method to be called on the ClientConn. // ReportError method to be called on the ClientConn.
suErr := errors.New("bad serviceupdate") suErr := errors.New("bad serviceupdate")
xdsC.InvokeWatchServiceCallback("", suErr) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr)
if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr { if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr {
t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr) t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr)
} }
@ -318,9 +320,25 @@ func TestXDSResolverGoodServiceUpdate(t *testing.T) {
waitForWatchService(t, xdsC, targetStr) waitForWatchService(t, xdsC, targetStr)
for _, tt := range []struct {
su client.ServiceUpdate
wantJSON string
}{
{
su: client.ServiceUpdate{WeightedCluster: map[string]uint32{testCluster1: 1}},
wantJSON: testClusterOnlyJSON,
},
{
su: client.ServiceUpdate{WeightedCluster: map[string]uint32{
"cluster_1": 75,
"cluster_2": 25,
}},
wantJSON: testWeightedCDSJSON,
},
} {
// Invoke the watchAPI callback with a good service update and wait for the // Invoke the watchAPI callback with a good service update and wait for the
// UpdateState method to be called on the ClientConn. // UpdateState method to be called on the ClientConn.
xdsC.InvokeWatchServiceCallback(cluster, nil) xdsC.InvokeWatchServiceCallback(tt.su, nil)
gotState, err := tcc.stateCh.Receive() gotState, err := tcc.stateCh.Receive()
if err != nil { if err != nil {
t.Fatalf("ClientConn.UpdateState returned error: %v", err) t.Fatalf("ClientConn.UpdateState returned error: %v", err)
@ -332,6 +350,14 @@ func TestXDSResolverGoodServiceUpdate(t *testing.T) {
if err := rState.ServiceConfig.Err; err != nil { if err := rState.ServiceConfig.Err; err != nil {
t.Fatalf("ClientConn.UpdateState received error in service config: %v", rState.ServiceConfig.Err) t.Fatalf("ClientConn.UpdateState received error in service config: %v", rState.ServiceConfig.Err)
} }
wantSCParsed := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(tt.wantJSON)
if !internal.EqualServiceConfigForTesting(rState.ServiceConfig.Config, wantSCParsed.Config) {
t.Errorf("ClientConn.UpdateState received different service config")
t.Error("got: ", cmp.Diff(nil, rState.ServiceConfig.Config))
t.Error("want: ", cmp.Diff(nil, wantSCParsed.Config))
}
}
} }
// TestXDSResolverUpdates tests the cases where the resolver gets a good update // TestXDSResolverUpdates tests the cases where the resolver gets a good update
@ -352,14 +378,14 @@ func TestXDSResolverGoodUpdateAfterError(t *testing.T) {
// Invoke the watchAPI callback with a bad service update and wait for the // Invoke the watchAPI callback with a bad service update and wait for the
// ReportError method to be called on the ClientConn. // ReportError method to be called on the ClientConn.
suErr := errors.New("bad serviceupdate") suErr := errors.New("bad serviceupdate")
xdsC.InvokeWatchServiceCallback("", suErr) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr)
if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr { if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr {
t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr) t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr)
} }
// Invoke the watchAPI callback with a good service update and wait for the // Invoke the watchAPI callback with a good service update and wait for the
// UpdateState method to be called on the ClientConn. // UpdateState method to be called on the ClientConn.
xdsC.InvokeWatchServiceCallback(cluster, nil) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{WeightedCluster: map[string]uint32{cluster: 1}}, nil)
gotState, err := tcc.stateCh.Receive() gotState, err := tcc.stateCh.Receive()
if err != nil { if err != nil {
t.Fatalf("ClientConn.UpdateState returned error: %v", err) t.Fatalf("ClientConn.UpdateState returned error: %v", err)
@ -375,7 +401,7 @@ func TestXDSResolverGoodUpdateAfterError(t *testing.T) {
// Invoke the watchAPI callback with a bad service update and wait for the // Invoke the watchAPI callback with a bad service update and wait for the
// ReportError method to be called on the ClientConn. // ReportError method to be called on the ClientConn.
suErr2 := errors.New("bad serviceupdate 2") suErr2 := errors.New("bad serviceupdate 2")
xdsC.InvokeWatchServiceCallback("", suErr2) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr2)
if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr2 { if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr2 {
t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr2) t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr2)
} }

View File

@ -69,11 +69,11 @@ func (xdsC *Client) WaitForWatchService() (string, error) {
} }
// InvokeWatchServiceCallback invokes the registered service watch callback. // InvokeWatchServiceCallback invokes the registered service watch callback.
func (xdsC *Client) InvokeWatchServiceCallback(cluster string, err error) { func (xdsC *Client) InvokeWatchServiceCallback(u xdsclient.ServiceUpdate, err error) {
xdsC.mu.Lock() xdsC.mu.Lock()
defer xdsC.mu.Unlock() defer xdsC.mu.Unlock()
xdsC.serviceCb(xdsclient.ServiceUpdate{Cluster: cluster}, err) xdsC.serviceCb(u, err)
} }
// WatchCluster registers a CDS watch. // WatchCluster registers a CDS watch.