rls: LB policy with only control channel handling (#3496)

This commit is contained in:
Easwar Swaminathan 2020-04-28 10:47:24 -07:00 committed by GitHub
parent b2df44eac8
commit b0ac601168
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 654 additions and 107 deletions

View File

@ -0,0 +1,211 @@
// +build go1.10
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver"
)
var (
_ balancer.Balancer = (*rlsBalancer)(nil)
_ balancer.V2Balancer = (*rlsBalancer)(nil)
// For overriding in tests.
newRLSClientFunc = newRLSClient
)
// rlsBalancer implements the RLS LB policy.
type rlsBalancer struct {
done *grpcsync.Event
cc balancer.ClientConn
opts balancer.BuildOptions
// Mutex protects all the state maintained by the LB policy.
// TODO(easwars): Once we add the cache, we will also have another lock for
// the cache alone.
mu sync.Mutex
lbCfg *lbConfig // Most recently received service config.
rlsCC *grpc.ClientConn // ClientConn to the RLS server.
rlsC *rlsClient // RLS client wrapper.
ccUpdateCh chan *balancer.ClientConnState
}
// run is a long running goroutine which handles all the updates that the
// balancer wishes to handle. The appropriate updateHandler will push the update
// on to a channel that this goroutine will select on, thereby the handling of
// the update will happen asynchronously.
func (lb *rlsBalancer) run() {
for {
// TODO(easwars): Handle other updates like subConn state changes, RLS
// responses from the server etc.
select {
case u := <-lb.ccUpdateCh:
lb.handleClientConnUpdate(u)
case <-lb.done.Done():
return
}
}
}
// handleClientConnUpdate handles updates to the service config.
// If the RLS server name or the RLS RPC timeout changes, it updates the control
// channel accordingly.
// TODO(easwars): Handle updates to other fields in the service config.
func (lb *rlsBalancer) handleClientConnUpdate(ccs *balancer.ClientConnState) {
grpclog.Infof("rls: service config: %+v", ccs.BalancerConfig)
lb.mu.Lock()
defer lb.mu.Unlock()
if lb.done.HasFired() {
grpclog.Warning("rls: received service config after balancer close")
return
}
newCfg := ccs.BalancerConfig.(*lbConfig)
if lb.lbCfg.Equal(newCfg) {
grpclog.Info("rls: new service config matches existing config")
return
}
lb.updateControlChannel(newCfg)
lb.lbCfg = newCfg
}
// UpdateClientConnState pushes the received ClientConnState update on the
// update channel which will be processed asynchronously by the run goroutine.
// Implements balancer.V2Balancer interface.
func (lb *rlsBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
select {
case lb.ccUpdateCh <- &ccs:
case <-lb.done.Done():
}
return nil
}
// ResolverErr implements balancer.V2Balancer interface.
func (lb *rlsBalancer) ResolverError(error) {
// ResolverError is called by gRPC when the name resolver reports an error.
// TODO(easwars): How do we handle this?
grpclog.Fatal("rls: ResolverError is not yet unimplemented")
}
// UpdateSubConnState implements balancer.V2Balancer interface.
func (lb *rlsBalancer) UpdateSubConnState(_ balancer.SubConn, _ balancer.SubConnState) {
grpclog.Fatal("rls: UpdateSubConnState is not yet implemented")
}
// Cleans up the resources allocated by the LB policy including the clientConn
// to the RLS server.
// Implements balancer.Balancer and balancer.V2Balancer interfaces.
func (lb *rlsBalancer) Close() {
lb.mu.Lock()
defer lb.mu.Unlock()
lb.done.Fire()
if lb.rlsCC != nil {
lb.rlsCC.Close()
}
}
// HandleSubConnStateChange implements balancer.Balancer interface.
func (lb *rlsBalancer) HandleSubConnStateChange(_ balancer.SubConn, _ connectivity.State) {
grpclog.Fatal("UpdateSubConnState should be called instead of HandleSubConnStateChange")
}
// HandleResolvedAddrs implements balancer.Balancer interface.
func (lb *rlsBalancer) HandleResolvedAddrs(_ []resolver.Address, _ error) {
grpclog.Fatal("UpdateClientConnState should be called instead of HandleResolvedAddrs")
}
// updateControlChannel updates the RLS client if required.
// Caller must hold lb.mu.
func (lb *rlsBalancer) updateControlChannel(newCfg *lbConfig) {
oldCfg := lb.lbCfg
if newCfg.lookupService == oldCfg.lookupService && newCfg.lookupServiceTimeout == oldCfg.lookupServiceTimeout {
return
}
// Use RPC timeout from new config, if different from existing one.
timeout := oldCfg.lookupServiceTimeout
if timeout != newCfg.lookupServiceTimeout {
timeout = newCfg.lookupServiceTimeout
}
if newCfg.lookupService == oldCfg.lookupService {
// This is the case where only the timeout has changed. We will continue
// to use the existing clientConn. but will create a new rlsClient with
// the new timeout.
lb.rlsC = newRLSClientFunc(lb.rlsCC, lb.opts.Target.Endpoint, timeout)
return
}
// This is the case where the RLS server name has changed. We need to create
// a new clientConn and close the old one.
var dopts []grpc.DialOption
if dialer := lb.opts.Dialer; dialer != nil {
dopts = append(dopts, grpc.WithContextDialer(dialer))
}
dopts = append(dopts, dialCreds(lb.opts))
cc, err := grpc.Dial(newCfg.lookupService, dopts...)
if err != nil {
grpclog.Errorf("rls: dialRLS(%s, %v): %v", newCfg.lookupService, lb.opts, err)
// An error from a non-blocking dial indicates something serious. We
// should continue to use the old control channel if one exists, and
// return so that the rest of the config updates can be processes.
return
}
if lb.rlsCC != nil {
lb.rlsCC.Close()
}
lb.rlsCC = cc
lb.rlsC = newRLSClientFunc(cc, lb.opts.Target.Endpoint, timeout)
}
func dialCreds(opts balancer.BuildOptions) grpc.DialOption {
// The control channel should use the same authority as that of the parent
// channel. This ensures that the identify of the RLS server and that of the
// backend is the same, so if the RLS config is injected by an attacker, it
// cannot cause leakage of private information contained in headers set by
// the application.
server := opts.Target.Authority
switch {
case opts.DialCreds != nil:
if err := opts.DialCreds.OverrideServerName(server); err != nil {
grpclog.Warningf("rls: OverrideServerName(%s) = (%v), using Insecure", server, err)
return grpc.WithInsecure()
}
return grpc.WithTransportCredentials(opts.DialCreds)
case opts.CredsBundle != nil:
return grpc.WithTransportCredentials(opts.CredsBundle.TransportCredentials())
default:
grpclog.Warning("rls: no credentials available, using Insecure")
return grpc.WithInsecure()
}
}

View File

@ -0,0 +1,228 @@
// +build go1.10
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"net"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/testdata"
)
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
type listenerWrapper struct {
net.Listener
connCh *testutils.Channel
}
// Accept waits for and returns the next connection to the listener.
func (l *listenerWrapper) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.connCh.Send(c)
return c, nil
}
func setupwithListener(t *testing.T, opts ...grpc.ServerOption) (*fakeserver.Server, *listenerWrapper, func()) {
t.Helper()
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen(tcp, localhost:0): %v", err)
}
lw := &listenerWrapper{
Listener: l,
connCh: testutils.NewChannel(),
}
server, cleanup, err := fakeserver.Start(lw, opts...)
if err != nil {
t.Fatalf("fakeserver.Start(): %v", err)
}
t.Logf("Fake RLS server started at %s ...", server.Address)
return server, lw, cleanup
}
type testBalancerCC struct {
balancer.ClientConn
}
// TestUpdateControlChannelFirstConfig tests the scenario where the LB policy
// receives its first service config and verifies that a control channel to the
// RLS server specified in the serviceConfig is established.
func (s) TestUpdateControlChannelFirstConfig(t *testing.T) {
server, lis, cleanup := setupwithListener(t)
defer cleanup()
bb := balancer.Get(rlsBalancerName)
if bb == nil {
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
}
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer)
defer rlsB.Close()
t.Log("Built RLS LB policy ...")
lbCfg := &lbConfig{lookupService: server.Address}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
// TODO: Verify channel connectivity state once control channel connectivity
// state monitoring is in place.
// TODO: Verify RLS RPC can be made once we integrate with the picker.
}
// TestUpdateControlChannelSwitch tests the scenario where a control channel
// exists and the LB policy receives a new serviceConfig with a different RLS
// server name. Verifies that the new control channel is created and the old one
// is closed (the leakchecker takes care of this).
func (s) TestUpdateControlChannelSwitch(t *testing.T) {
server1, lis1, cleanup1 := setupwithListener(t)
defer cleanup1()
server2, lis2, cleanup2 := setupwithListener(t)
defer cleanup2()
bb := balancer.Get(rlsBalancerName)
if bb == nil {
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
}
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer)
defer rlsB.Close()
t.Log("Built RLS LB policy ...")
lbCfg := &lbConfig{lookupService: server1.Address}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis1.connCh.Receive(); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
lbCfg = &lbConfig{lookupService: server2.Address}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis2.connCh.Receive(); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
// TODO: Verify channel connectivity state once control channel connectivity
// state monitoring is in place.
// TODO: Verify RLS RPC can be made once we integrate with the picker.
}
// TestUpdateControlChannelTimeout tests the scenario where the LB policy
// receives a service config update with a different lookupServiceTimeout, but
// the lookupService itself remains unchanged. It verifies that the LB policy
// does not create a new control channel in this case.
func (s) TestUpdateControlChannelTimeout(t *testing.T) {
server, lis, cleanup := setupwithListener(t)
defer cleanup()
bb := balancer.Get(rlsBalancerName)
if bb == nil {
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
}
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer)
defer rlsB.Close()
t.Log("Built RLS LB policy ...")
lbCfg := &lbConfig{lookupService: server.Address, lookupServiceTimeout: 1 * time.Second}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
lbCfg = &lbConfig{lookupService: server.Address, lookupServiceTimeout: 2 * time.Second}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(); err != testutils.ErrRecvTimeout {
t.Fatal("LB policy created new control channel when only lookupServiceTimeout changed")
}
// TODO: Verify channel connectivity state once control channel connectivity
// state monitoring is in place.
// TODO: Verify RLS RPC can be made once we integrate with the picker.
}
// TestUpdateControlChannelWithCreds tests the scenario where the control
// channel is to established with credentials from the parent channel.
func (s) TestUpdateControlChannelWithCreds(t *testing.T) {
sCreds, err := credentials.NewServerTLSFromFile(testdata.Path("server1.pem"), testdata.Path("server1.key"))
if err != nil {
t.Fatalf("credentials.NewServerTLSFromFile(server1.pem, server1.key) = %v", err)
}
cCreds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "")
if err != nil {
t.Fatalf("credentials.NewClientTLSFromFile(ca.pem) = %v", err)
}
server, lis, cleanup := setupwithListener(t, grpc.Creds(sCreds))
defer cleanup()
bb := balancer.Get(rlsBalancerName)
if bb == nil {
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
}
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{
DialCreds: cCreds,
}).(balancer.V2Balancer)
defer rlsB.Close()
t.Log("Built RLS LB policy ...")
lbCfg := &lbConfig{lookupService: server.Address}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
// TODO: Verify channel connectivity state once control channel connectivity
// state monitoring is in place.
// TODO: Verify RLS RPC can be made once we integrate with the picker.
}

View File

@ -21,16 +21,35 @@
// Package rls implements the RLS LB policy.
package rls
import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal/grpcsync"
)
const rlsBalancerName = "rls"
func init() {
balancer.Register(&rlsBB{})
}
// rlsBB helps build RLS load balancers and parse the service config to be
// passed to the RLS load balancer.
type rlsBB struct {
// TODO(easwars): Implement the Build() method and register the builder.
}
type rlsBB struct{}
// Name returns the name of the RLS LB policy and helps implement the
// balancer.Balancer interface.
func (*rlsBB) Name() string {
return rlsBalancerName
}
func (*rlsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
lb := &rlsBalancer{
done: grpcsync.NewEvent(),
cc: cc,
opts: opts,
lbCfg: &lbConfig{},
ccUpdateCh: make(chan *balancer.ClientConnState),
}
go lb.run()
return lb
}

View File

@ -43,7 +43,6 @@ const grpcTargetType = "grpc"
// throttling and asks this client to make an RPC call only after checking with
// the throttler.
type rlsClient struct {
cc *grpc.ClientConn
stub rlspb.RouteLookupServiceClient
// origDialTarget is the original dial target of the user and sent in each
// RouteLookup RPC made to the RLS server.
@ -55,7 +54,6 @@ type rlsClient struct {
func newRLSClient(cc *grpc.ClientConn, dialTarget string, rpcTimeout time.Duration) *rlsClient {
return &rlsClient{
cc: cc,
stub: rlspb.NewRouteLookupServiceClient(cc),
origDialTarget: dialTarget,
rpcTimeout: rpcTimeout,

View File

@ -1,3 +1,5 @@
// +build go1.10
/*
*
* Copyright 2020 gRPC authors.
@ -30,25 +32,26 @@ import (
rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/status"
)
const (
defaultDialTarget = "dummy"
defaultRPCTimeout = 5 * time.Second
defaultTestTimeout = 1 * time.Second
defaultDialTarget = "dummy"
defaultRPCTimeout = 5 * time.Second
)
func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) {
t.Helper()
server, sCleanup, err := fakeserver.Start()
server, sCleanup, err := fakeserver.Start(nil)
if err != nil {
t.Fatalf("Failed to start fake RLS server: %v", err)
}
cc, cCleanup, err := server.ClientConn()
if err != nil {
sCleanup()
t.Fatalf("Failed to get a ClientConn to the RLS server: %v", err)
}
@ -59,7 +62,7 @@ func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) {
}
// TestLookupFailure verifies the case where the RLS server returns an error.
func TestLookupFailure(t *testing.T) {
func (s) TestLookupFailure(t *testing.T) {
server, cc, cleanup := setup(t)
defer cleanup()
@ -68,64 +71,50 @@ func TestLookupFailure(t *testing.T) {
rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout)
errCh := make(chan error)
errCh := testutils.NewChannel()
rlsClient.lookup("", nil, func(targets []string, headerData string, err error) {
if err == nil {
errCh <- errors.New("rlsClient.lookup() succeeded, should have failed")
errCh.Send(errors.New("rlsClient.lookup() succeeded, should have failed"))
return
}
if len(targets) != 0 || headerData != "" {
errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData)
errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData))
return
}
errCh <- nil
errCh.Send(nil)
})
timer := time.NewTimer(defaultTestTimeout)
select {
case <-timer.C:
t.Fatal("Timeout when expecting a routeLookup callback")
case err := <-errCh:
timer.Stop()
if err != nil {
t.Fatal(err)
}
if e, err := errCh.Receive(); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
}
}
// TestLookupDeadlineExceeded tests the case where the RPC deadline associated
// with the lookup expires.
func TestLookupDeadlineExceeded(t *testing.T) {
func (s) TestLookupDeadlineExceeded(t *testing.T) {
_, cc, cleanup := setup(t)
defer cleanup()
// Give the Lookup RPC a small deadline, but don't setup the fake server to
// return anything. So the Lookup call will block and eventuall expire.
// return anything. So the Lookup call will block and eventually expire.
rlsClient := newRLSClient(cc, defaultDialTarget, 100*time.Millisecond)
errCh := make(chan error)
errCh := testutils.NewChannel()
rlsClient.lookup("", nil, func(_ []string, _ string, err error) {
if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)
errCh.Send(fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded))
return
}
errCh <- nil
errCh.Send(nil)
})
timer := time.NewTimer(defaultTestTimeout)
select {
case <-timer.C:
t.Fatal("Timeout when expecting a routeLookup callback")
case err := <-errCh:
timer.Stop()
if err != nil {
t.Fatal(err)
}
if e, err := errCh.Receive(); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
}
}
// TestLookupSuccess verifies the successful Lookup API case.
func TestLookupSuccess(t *testing.T) {
func (s) TestLookupSuccess(t *testing.T) {
server, cc, cleanup := setup(t)
defer cleanup()
@ -148,33 +137,29 @@ func TestLookupSuccess(t *testing.T) {
rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout)
errCh := make(chan error)
errCh := testutils.NewChannel()
rlsClient.lookup(rlsReqPath, rlsReqKeyMap, func(targets []string, hd string, err error) {
if err != nil {
errCh <- fmt.Errorf("rlsClient.Lookup() failed: %v", err)
errCh.Send(fmt.Errorf("rlsClient.Lookup() failed: %v", err))
return
}
if !cmp.Equal(targets, wantRespTargets) || hd != wantHeaderData {
errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData)
errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData))
return
}
errCh <- nil
errCh.Send(nil)
})
// Make sure that the fake server received the expected RouteLookupRequest
// proto.
timer := time.NewTimer(defaultTestTimeout)
select {
case gotLookupRequest := <-server.RequestChan:
if !timer.Stop() {
<-timer.C
}
if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" {
t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff)
}
case <-timer.C:
req, err := server.RequestChan.Receive()
if err != nil {
t.Fatalf("Timed out wile waiting for a RouteLookupRequest")
}
gotLookupRequest := req.(*rlspb.RouteLookupRequest)
if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" {
t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff)
}
// We setup the fake server to return this response when it receives a
// request.
@ -185,14 +170,7 @@ func TestLookupSuccess(t *testing.T) {
},
}
timer = time.NewTimer(defaultTestTimeout)
select {
case <-timer.C:
t.Fatal("Timeout when expecting a routeLookup callback")
case err := <-errCh:
timer.Stop()
if err != nil {
t.Fatal(err)
}
if e, err := errCh.Receive(); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
}
}

View File

@ -71,6 +71,40 @@ type lbConfig struct {
cpConfig map[string]json.RawMessage
}
func (lbCfg *lbConfig) Equal(other *lbConfig) bool {
return lbCfg.kbMap.Equal(other.kbMap) &&
lbCfg.lookupService == other.lookupService &&
lbCfg.lookupServiceTimeout == other.lookupServiceTimeout &&
lbCfg.maxAge == other.maxAge &&
lbCfg.staleAge == other.staleAge &&
lbCfg.cacheSizeBytes == other.cacheSizeBytes &&
lbCfg.rpStrategy == other.rpStrategy &&
lbCfg.defaultTarget == other.defaultTarget &&
lbCfg.cpName == other.cpName &&
lbCfg.cpTargetField == other.cpTargetField &&
cpConfigEqual(lbCfg.cpConfig, other.cpConfig)
}
func cpConfigEqual(am, bm map[string]json.RawMessage) bool {
if (bm == nil) != (am == nil) {
return false
}
if len(bm) != len(am) {
return false
}
for k, jsonA := range am {
jsonB, ok := bm[k]
if !ok {
return false
}
if !bytes.Equal(jsonA, jsonB) {
return false
}
}
return true
}
// This struct resembles the JSON respresentation of the loadBalancing config
// and makes it easier to unmarshal.
type lbConfigJSON struct {

View File

@ -49,20 +49,22 @@ func init() {
balancer.Register(&dummyBB{})
}
func (lbCfg *lbConfig) Equal(other *lbConfig) bool {
// This only ignores the keyBuilderMap field because its internals are not
// exported, and hence not possible to specify in the want section of the
// test.
return lbCfg.lookupService == other.lookupService &&
lbCfg.lookupServiceTimeout == other.lookupServiceTimeout &&
lbCfg.maxAge == other.maxAge &&
lbCfg.staleAge == other.staleAge &&
lbCfg.cacheSizeBytes == other.cacheSizeBytes &&
lbCfg.rpStrategy == other.rpStrategy &&
lbCfg.defaultTarget == other.defaultTarget &&
lbCfg.cpName == other.cpName &&
lbCfg.cpTargetField == other.cpTargetField &&
cmp.Equal(lbCfg.cpConfig, other.cpConfig)
// testEqual reports whether the lbCfgs a and b are equal. This is to be used
// only from tests. This ignores the keyBuilderMap field because its internals
// are not exported, and hence not possible to specify in the want section of
// the test. This is fine because we already have tests to make sure that the
// keyBuilder is parsed properly from the service config.
func testEqual(a, b *lbConfig) bool {
return a.lookupService == b.lookupService &&
a.lookupServiceTimeout == b.lookupServiceTimeout &&
a.maxAge == b.maxAge &&
a.staleAge == b.staleAge &&
a.cacheSizeBytes == b.cacheSizeBytes &&
a.rpStrategy == b.rpStrategy &&
a.defaultTarget == b.defaultTarget &&
a.cpName == b.cpName &&
a.cpTargetField == b.cpTargetField &&
cmp.Equal(a.cpConfig, b.cpConfig)
}
func TestParseConfig(t *testing.T) {
@ -152,7 +154,7 @@ func TestParseConfig(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
lbCfg, err := builder.ParseConfig(test.input)
if err != nil || !cmp.Equal(lbCfg, test.wantCfg) {
if err != nil || !testEqual(lbCfg.(*lbConfig), test.wantCfg) {
t.Errorf("ParseConfig(%s) = {%+v, %v}, want {%+v, nil}", string(test.input), lbCfg, err, test.wantCfg)
}
})

View File

@ -28,13 +28,13 @@ import (
"testing"
"time"
"google.golang.org/grpc/internal/grpcrand"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/rls/internal/cache"
"google.golang.org/grpc/balancer/rls/internal/keys"
rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
)
@ -502,7 +502,7 @@ func TestPick(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
rlsCh := make(chan error, 1)
rlsCh := testutils.NewChannel()
randID := grpcrand.Intn(math.MaxInt32)
// We instantiate a fakeChildPicker which will return a fakeSubConn
// with configured id. Either the childPicker or the defaultPicker
@ -525,18 +525,18 @@ func TestPick(t *testing.T) {
shouldThrottle: func() bool { return test.throttle },
startRLS: func(path string, km keys.KeyMap) {
if !test.newRLSRequest {
rlsCh <- errors.New("RLS request attempted when none was expected")
rlsCh.Send(errors.New("RLS request attempted when none was expected"))
return
}
if path != rpcPath {
rlsCh <- fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath)
rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath))
return
}
if km.Str != wantKeyMapStr {
rlsCh <- fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr)
rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr))
return
}
rlsCh <- nil
rlsCh.Send(nil)
},
defaultPick: func(info balancer.PickInfo) (balancer.PickResult, error) {
if !test.useDefaultPick {
@ -569,15 +569,8 @@ func TestPick(t *testing.T) {
// If the test specified that a new RLS request should be made,
// verify it.
if test.newRLSRequest {
timer := time.NewTimer(defaultTestTimeout)
select {
case err := <-rlsCh:
timer.Stop()
if err != nil {
t.Fatal(err)
}
case <-timer.C:
t.Fatal("Timeout waiting for RLS request to be sent out")
if rlsErr, err := rlsCh.Receive(); err != nil || rlsErr != nil {
t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err)
}
}
})

View File

@ -22,6 +22,7 @@ package fakeserver
import (
"context"
"errors"
"fmt"
"net"
"time"
@ -29,9 +30,14 @@ import (
"google.golang.org/grpc"
rlsgrpc "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1"
rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/internal/testutils"
)
const defaultDialTimeout = 5 * time.Second
const (
defaultDialTimeout = 5 * time.Second
defaultRPCTimeout = 5 * time.Second
defaultChannelBufferSize = 50
)
// Response wraps the response protobuf (xds/LRS) and error that the Server
// should send out to the client through a call to stream.Send()
@ -43,29 +49,31 @@ type Response struct {
// Server is a fake implementation of RLS. It exposes channels to send/receive
// RLS requests and responses.
type Server struct {
RequestChan chan *rlspb.RouteLookupRequest
RequestChan *testutils.Channel
ResponseChan chan Response
Address string
}
// Start makes a new Server and gets it to start listening on a local port for
// gRPC requests. The returned cancel function should be invoked by the caller
// upon completion of the test.
func Start() (*Server, func(), error) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, func() {}, fmt.Errorf("net.Listen() failed: %v", err)
// Start makes a new Server which uses the provided net.Listener. If lis is nil,
// it creates a new net.Listener on a local port. The returned cancel function
// should be invoked by the caller upon completion of the test.
func Start(lis net.Listener, opts ...grpc.ServerOption) (*Server, func(), error) {
if lis == nil {
var err error
lis, err = net.Listen("tcp", "localhost:0")
if err != nil {
return nil, func() {}, fmt.Errorf("net.Listen() failed: %v", err)
}
}
s := &Server{
// Give the channels a buffer size of 1 so that we can setup
// expectations for one lookup call, without blocking.
RequestChan: make(chan *rlspb.RouteLookupRequest, 1),
RequestChan: testutils.NewChannelWithSize(defaultChannelBufferSize),
ResponseChan: make(chan Response, 1),
Address: lis.Addr().String(),
}
server := grpc.NewServer()
server := grpc.NewServer(opts...)
rlsgrpc.RegisterRouteLookupServiceServer(server, s)
go server.Serve(lis)
@ -74,9 +82,17 @@ func Start() (*Server, func(), error) {
// RouteLookup implements the RouteLookupService.
func (s *Server) RouteLookup(ctx context.Context, req *rlspb.RouteLookupRequest) (*rlspb.RouteLookupResponse, error) {
s.RequestChan <- req
resp := <-s.ResponseChan
return resp.Resp, resp.Err
s.RequestChan.Send(req)
// The leakchecker fails if we don't exit out of here in a reasonable time.
timer := time.NewTimer(defaultRPCTimeout)
select {
case <-timer.C:
return nil, errors.New("default RPC timeout exceeded")
case resp := <-s.ResponseChan:
timer.Stop()
return resp.Resp, resp.Err
}
}
// ClientConn returns a grpc.ClientConn connected to the fakeServer.

View File

@ -0,0 +1,68 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package testutils
import (
"errors"
"time"
)
// ErrRecvTimeout is an error to indicate that a receive operation on the
// channel timed out.
var ErrRecvTimeout = errors.New("timed out when waiting for value on channel")
const (
// DefaultChanRecvTimeout is the default timeout for receive operations on the
// underlying channel.
DefaultChanRecvTimeout = 1 * time.Second
// DefaultChanBufferSize is the default buffer size of the underlying channel.
DefaultChanBufferSize = 1
)
// Channel wraps a generic channel and provides a timed receive operation.
type Channel struct {
ch chan interface{}
}
// Send sends value on the underlying channel.
func (cwt *Channel) Send(value interface{}) {
cwt.ch <- value
}
// Receive returns the value received on the underlying channel, or
// ErrRecvTimeout if DefaultChanRecvTimeout amount of time elapses.
func (cwt *Channel) Receive() (interface{}, error) {
timer := time.NewTimer(DefaultChanRecvTimeout)
select {
case <-timer.C:
return nil, ErrRecvTimeout
case got := <-cwt.ch:
timer.Stop()
return got, nil
}
}
// NewChannel returns a new Channel.
func NewChannel() *Channel {
return NewChannelWithSize(DefaultChanBufferSize)
}
// NewChannelWithSize returns a new Channel with a buffer of bufSize.
func NewChannelWithSize(bufSize int) *Channel {
return &Channel{ch: make(chan interface{}, bufSize)}
}