mirror of https://github.com/grpc/grpc-go.git
rls: LB policy with only control channel handling (#3496)
This commit is contained in:
parent
b2df44eac8
commit
b0ac601168
|
|
@ -0,0 +1,211 @@
|
|||
// +build go1.10
|
||||
|
||||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package rls
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/internal/grpcsync"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
var (
|
||||
_ balancer.Balancer = (*rlsBalancer)(nil)
|
||||
_ balancer.V2Balancer = (*rlsBalancer)(nil)
|
||||
|
||||
// For overriding in tests.
|
||||
newRLSClientFunc = newRLSClient
|
||||
)
|
||||
|
||||
// rlsBalancer implements the RLS LB policy.
|
||||
type rlsBalancer struct {
|
||||
done *grpcsync.Event
|
||||
cc balancer.ClientConn
|
||||
opts balancer.BuildOptions
|
||||
|
||||
// Mutex protects all the state maintained by the LB policy.
|
||||
// TODO(easwars): Once we add the cache, we will also have another lock for
|
||||
// the cache alone.
|
||||
mu sync.Mutex
|
||||
lbCfg *lbConfig // Most recently received service config.
|
||||
rlsCC *grpc.ClientConn // ClientConn to the RLS server.
|
||||
rlsC *rlsClient // RLS client wrapper.
|
||||
|
||||
ccUpdateCh chan *balancer.ClientConnState
|
||||
}
|
||||
|
||||
// run is a long running goroutine which handles all the updates that the
|
||||
// balancer wishes to handle. The appropriate updateHandler will push the update
|
||||
// on to a channel that this goroutine will select on, thereby the handling of
|
||||
// the update will happen asynchronously.
|
||||
func (lb *rlsBalancer) run() {
|
||||
for {
|
||||
// TODO(easwars): Handle other updates like subConn state changes, RLS
|
||||
// responses from the server etc.
|
||||
select {
|
||||
case u := <-lb.ccUpdateCh:
|
||||
lb.handleClientConnUpdate(u)
|
||||
case <-lb.done.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleClientConnUpdate handles updates to the service config.
|
||||
// If the RLS server name or the RLS RPC timeout changes, it updates the control
|
||||
// channel accordingly.
|
||||
// TODO(easwars): Handle updates to other fields in the service config.
|
||||
func (lb *rlsBalancer) handleClientConnUpdate(ccs *balancer.ClientConnState) {
|
||||
grpclog.Infof("rls: service config: %+v", ccs.BalancerConfig)
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
if lb.done.HasFired() {
|
||||
grpclog.Warning("rls: received service config after balancer close")
|
||||
return
|
||||
}
|
||||
|
||||
newCfg := ccs.BalancerConfig.(*lbConfig)
|
||||
if lb.lbCfg.Equal(newCfg) {
|
||||
grpclog.Info("rls: new service config matches existing config")
|
||||
return
|
||||
}
|
||||
|
||||
lb.updateControlChannel(newCfg)
|
||||
lb.lbCfg = newCfg
|
||||
}
|
||||
|
||||
// UpdateClientConnState pushes the received ClientConnState update on the
|
||||
// update channel which will be processed asynchronously by the run goroutine.
|
||||
// Implements balancer.V2Balancer interface.
|
||||
func (lb *rlsBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
|
||||
select {
|
||||
case lb.ccUpdateCh <- &ccs:
|
||||
case <-lb.done.Done():
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResolverErr implements balancer.V2Balancer interface.
|
||||
func (lb *rlsBalancer) ResolverError(error) {
|
||||
// ResolverError is called by gRPC when the name resolver reports an error.
|
||||
// TODO(easwars): How do we handle this?
|
||||
grpclog.Fatal("rls: ResolverError is not yet unimplemented")
|
||||
}
|
||||
|
||||
// UpdateSubConnState implements balancer.V2Balancer interface.
|
||||
func (lb *rlsBalancer) UpdateSubConnState(_ balancer.SubConn, _ balancer.SubConnState) {
|
||||
grpclog.Fatal("rls: UpdateSubConnState is not yet implemented")
|
||||
}
|
||||
|
||||
// Cleans up the resources allocated by the LB policy including the clientConn
|
||||
// to the RLS server.
|
||||
// Implements balancer.Balancer and balancer.V2Balancer interfaces.
|
||||
func (lb *rlsBalancer) Close() {
|
||||
lb.mu.Lock()
|
||||
defer lb.mu.Unlock()
|
||||
|
||||
lb.done.Fire()
|
||||
if lb.rlsCC != nil {
|
||||
lb.rlsCC.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSubConnStateChange implements balancer.Balancer interface.
|
||||
func (lb *rlsBalancer) HandleSubConnStateChange(_ balancer.SubConn, _ connectivity.State) {
|
||||
grpclog.Fatal("UpdateSubConnState should be called instead of HandleSubConnStateChange")
|
||||
}
|
||||
|
||||
// HandleResolvedAddrs implements balancer.Balancer interface.
|
||||
func (lb *rlsBalancer) HandleResolvedAddrs(_ []resolver.Address, _ error) {
|
||||
grpclog.Fatal("UpdateClientConnState should be called instead of HandleResolvedAddrs")
|
||||
}
|
||||
|
||||
// updateControlChannel updates the RLS client if required.
|
||||
// Caller must hold lb.mu.
|
||||
func (lb *rlsBalancer) updateControlChannel(newCfg *lbConfig) {
|
||||
oldCfg := lb.lbCfg
|
||||
if newCfg.lookupService == oldCfg.lookupService && newCfg.lookupServiceTimeout == oldCfg.lookupServiceTimeout {
|
||||
return
|
||||
}
|
||||
|
||||
// Use RPC timeout from new config, if different from existing one.
|
||||
timeout := oldCfg.lookupServiceTimeout
|
||||
if timeout != newCfg.lookupServiceTimeout {
|
||||
timeout = newCfg.lookupServiceTimeout
|
||||
}
|
||||
|
||||
if newCfg.lookupService == oldCfg.lookupService {
|
||||
// This is the case where only the timeout has changed. We will continue
|
||||
// to use the existing clientConn. but will create a new rlsClient with
|
||||
// the new timeout.
|
||||
lb.rlsC = newRLSClientFunc(lb.rlsCC, lb.opts.Target.Endpoint, timeout)
|
||||
return
|
||||
}
|
||||
|
||||
// This is the case where the RLS server name has changed. We need to create
|
||||
// a new clientConn and close the old one.
|
||||
var dopts []grpc.DialOption
|
||||
if dialer := lb.opts.Dialer; dialer != nil {
|
||||
dopts = append(dopts, grpc.WithContextDialer(dialer))
|
||||
}
|
||||
dopts = append(dopts, dialCreds(lb.opts))
|
||||
|
||||
cc, err := grpc.Dial(newCfg.lookupService, dopts...)
|
||||
if err != nil {
|
||||
grpclog.Errorf("rls: dialRLS(%s, %v): %v", newCfg.lookupService, lb.opts, err)
|
||||
// An error from a non-blocking dial indicates something serious. We
|
||||
// should continue to use the old control channel if one exists, and
|
||||
// return so that the rest of the config updates can be processes.
|
||||
return
|
||||
}
|
||||
if lb.rlsCC != nil {
|
||||
lb.rlsCC.Close()
|
||||
}
|
||||
lb.rlsCC = cc
|
||||
lb.rlsC = newRLSClientFunc(cc, lb.opts.Target.Endpoint, timeout)
|
||||
}
|
||||
|
||||
func dialCreds(opts balancer.BuildOptions) grpc.DialOption {
|
||||
// The control channel should use the same authority as that of the parent
|
||||
// channel. This ensures that the identify of the RLS server and that of the
|
||||
// backend is the same, so if the RLS config is injected by an attacker, it
|
||||
// cannot cause leakage of private information contained in headers set by
|
||||
// the application.
|
||||
server := opts.Target.Authority
|
||||
switch {
|
||||
case opts.DialCreds != nil:
|
||||
if err := opts.DialCreds.OverrideServerName(server); err != nil {
|
||||
grpclog.Warningf("rls: OverrideServerName(%s) = (%v), using Insecure", server, err)
|
||||
return grpc.WithInsecure()
|
||||
}
|
||||
return grpc.WithTransportCredentials(opts.DialCreds)
|
||||
case opts.CredsBundle != nil:
|
||||
return grpc.WithTransportCredentials(opts.CredsBundle.TransportCredentials())
|
||||
default:
|
||||
grpclog.Warning("rls: no credentials available, using Insecure")
|
||||
return grpc.WithInsecure()
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
// +build go1.10
|
||||
|
||||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package rls
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/internal/testutils"
|
||||
"google.golang.org/grpc/testdata"
|
||||
)
|
||||
|
||||
type s struct {
|
||||
grpctest.Tester
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
grpctest.RunSubTests(t, s{})
|
||||
}
|
||||
|
||||
type listenerWrapper struct {
|
||||
net.Listener
|
||||
connCh *testutils.Channel
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (l *listenerWrapper) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.connCh.Send(c)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func setupwithListener(t *testing.T, opts ...grpc.ServerOption) (*fakeserver.Server, *listenerWrapper, func()) {
|
||||
t.Helper()
|
||||
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen(tcp, localhost:0): %v", err)
|
||||
}
|
||||
lw := &listenerWrapper{
|
||||
Listener: l,
|
||||
connCh: testutils.NewChannel(),
|
||||
}
|
||||
|
||||
server, cleanup, err := fakeserver.Start(lw, opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("fakeserver.Start(): %v", err)
|
||||
}
|
||||
t.Logf("Fake RLS server started at %s ...", server.Address)
|
||||
|
||||
return server, lw, cleanup
|
||||
}
|
||||
|
||||
type testBalancerCC struct {
|
||||
balancer.ClientConn
|
||||
}
|
||||
|
||||
// TestUpdateControlChannelFirstConfig tests the scenario where the LB policy
|
||||
// receives its first service config and verifies that a control channel to the
|
||||
// RLS server specified in the serviceConfig is established.
|
||||
func (s) TestUpdateControlChannelFirstConfig(t *testing.T) {
|
||||
server, lis, cleanup := setupwithListener(t)
|
||||
defer cleanup()
|
||||
|
||||
bb := balancer.Get(rlsBalancerName)
|
||||
if bb == nil {
|
||||
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
|
||||
}
|
||||
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer)
|
||||
defer rlsB.Close()
|
||||
t.Log("Built RLS LB policy ...")
|
||||
|
||||
lbCfg := &lbConfig{lookupService: server.Address}
|
||||
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
|
||||
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
|
||||
|
||||
if _, err := lis.connCh.Receive(); err != nil {
|
||||
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
|
||||
}
|
||||
|
||||
// TODO: Verify channel connectivity state once control channel connectivity
|
||||
// state monitoring is in place.
|
||||
|
||||
// TODO: Verify RLS RPC can be made once we integrate with the picker.
|
||||
}
|
||||
|
||||
// TestUpdateControlChannelSwitch tests the scenario where a control channel
|
||||
// exists and the LB policy receives a new serviceConfig with a different RLS
|
||||
// server name. Verifies that the new control channel is created and the old one
|
||||
// is closed (the leakchecker takes care of this).
|
||||
func (s) TestUpdateControlChannelSwitch(t *testing.T) {
|
||||
server1, lis1, cleanup1 := setupwithListener(t)
|
||||
defer cleanup1()
|
||||
|
||||
server2, lis2, cleanup2 := setupwithListener(t)
|
||||
defer cleanup2()
|
||||
|
||||
bb := balancer.Get(rlsBalancerName)
|
||||
if bb == nil {
|
||||
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
|
||||
}
|
||||
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer)
|
||||
defer rlsB.Close()
|
||||
t.Log("Built RLS LB policy ...")
|
||||
|
||||
lbCfg := &lbConfig{lookupService: server1.Address}
|
||||
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
|
||||
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
|
||||
|
||||
if _, err := lis1.connCh.Receive(); err != nil {
|
||||
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
|
||||
}
|
||||
|
||||
lbCfg = &lbConfig{lookupService: server2.Address}
|
||||
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
|
||||
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
|
||||
|
||||
if _, err := lis2.connCh.Receive(); err != nil {
|
||||
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
|
||||
}
|
||||
|
||||
// TODO: Verify channel connectivity state once control channel connectivity
|
||||
// state monitoring is in place.
|
||||
|
||||
// TODO: Verify RLS RPC can be made once we integrate with the picker.
|
||||
}
|
||||
|
||||
// TestUpdateControlChannelTimeout tests the scenario where the LB policy
|
||||
// receives a service config update with a different lookupServiceTimeout, but
|
||||
// the lookupService itself remains unchanged. It verifies that the LB policy
|
||||
// does not create a new control channel in this case.
|
||||
func (s) TestUpdateControlChannelTimeout(t *testing.T) {
|
||||
server, lis, cleanup := setupwithListener(t)
|
||||
defer cleanup()
|
||||
|
||||
bb := balancer.Get(rlsBalancerName)
|
||||
if bb == nil {
|
||||
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
|
||||
}
|
||||
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer)
|
||||
defer rlsB.Close()
|
||||
t.Log("Built RLS LB policy ...")
|
||||
|
||||
lbCfg := &lbConfig{lookupService: server.Address, lookupServiceTimeout: 1 * time.Second}
|
||||
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
|
||||
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
|
||||
if _, err := lis.connCh.Receive(); err != nil {
|
||||
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
|
||||
}
|
||||
|
||||
lbCfg = &lbConfig{lookupService: server.Address, lookupServiceTimeout: 2 * time.Second}
|
||||
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
|
||||
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
|
||||
if _, err := lis.connCh.Receive(); err != testutils.ErrRecvTimeout {
|
||||
t.Fatal("LB policy created new control channel when only lookupServiceTimeout changed")
|
||||
}
|
||||
|
||||
// TODO: Verify channel connectivity state once control channel connectivity
|
||||
// state monitoring is in place.
|
||||
|
||||
// TODO: Verify RLS RPC can be made once we integrate with the picker.
|
||||
}
|
||||
|
||||
// TestUpdateControlChannelWithCreds tests the scenario where the control
|
||||
// channel is to established with credentials from the parent channel.
|
||||
func (s) TestUpdateControlChannelWithCreds(t *testing.T) {
|
||||
sCreds, err := credentials.NewServerTLSFromFile(testdata.Path("server1.pem"), testdata.Path("server1.key"))
|
||||
if err != nil {
|
||||
t.Fatalf("credentials.NewServerTLSFromFile(server1.pem, server1.key) = %v", err)
|
||||
}
|
||||
cCreds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "")
|
||||
if err != nil {
|
||||
t.Fatalf("credentials.NewClientTLSFromFile(ca.pem) = %v", err)
|
||||
}
|
||||
|
||||
server, lis, cleanup := setupwithListener(t, grpc.Creds(sCreds))
|
||||
defer cleanup()
|
||||
|
||||
bb := balancer.Get(rlsBalancerName)
|
||||
if bb == nil {
|
||||
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
|
||||
}
|
||||
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{
|
||||
DialCreds: cCreds,
|
||||
}).(balancer.V2Balancer)
|
||||
defer rlsB.Close()
|
||||
t.Log("Built RLS LB policy ...")
|
||||
|
||||
lbCfg := &lbConfig{lookupService: server.Address}
|
||||
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
|
||||
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
|
||||
|
||||
if _, err := lis.connCh.Receive(); err != nil {
|
||||
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
|
||||
}
|
||||
|
||||
// TODO: Verify channel connectivity state once control channel connectivity
|
||||
// state monitoring is in place.
|
||||
|
||||
// TODO: Verify RLS RPC can be made once we integrate with the picker.
|
||||
}
|
||||
|
|
@ -21,16 +21,35 @@
|
|||
// Package rls implements the RLS LB policy.
|
||||
package rls
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/internal/grpcsync"
|
||||
)
|
||||
|
||||
const rlsBalancerName = "rls"
|
||||
|
||||
func init() {
|
||||
balancer.Register(&rlsBB{})
|
||||
}
|
||||
|
||||
// rlsBB helps build RLS load balancers and parse the service config to be
|
||||
// passed to the RLS load balancer.
|
||||
type rlsBB struct {
|
||||
// TODO(easwars): Implement the Build() method and register the builder.
|
||||
}
|
||||
type rlsBB struct{}
|
||||
|
||||
// Name returns the name of the RLS LB policy and helps implement the
|
||||
// balancer.Balancer interface.
|
||||
func (*rlsBB) Name() string {
|
||||
return rlsBalancerName
|
||||
}
|
||||
|
||||
func (*rlsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
|
||||
lb := &rlsBalancer{
|
||||
done: grpcsync.NewEvent(),
|
||||
cc: cc,
|
||||
opts: opts,
|
||||
lbCfg: &lbConfig{},
|
||||
ccUpdateCh: make(chan *balancer.ClientConnState),
|
||||
}
|
||||
go lb.run()
|
||||
return lb
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ const grpcTargetType = "grpc"
|
|||
// throttling and asks this client to make an RPC call only after checking with
|
||||
// the throttler.
|
||||
type rlsClient struct {
|
||||
cc *grpc.ClientConn
|
||||
stub rlspb.RouteLookupServiceClient
|
||||
// origDialTarget is the original dial target of the user and sent in each
|
||||
// RouteLookup RPC made to the RLS server.
|
||||
|
|
@ -55,7 +54,6 @@ type rlsClient struct {
|
|||
|
||||
func newRLSClient(cc *grpc.ClientConn, dialTarget string, rpcTimeout time.Duration) *rlsClient {
|
||||
return &rlsClient{
|
||||
cc: cc,
|
||||
stub: rlspb.NewRouteLookupServiceClient(cc),
|
||||
origDialTarget: dialTarget,
|
||||
rpcTimeout: rpcTimeout,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
// +build go1.10
|
||||
|
||||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
|
|
@ -30,25 +32,26 @@ import (
|
|||
rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1"
|
||||
"google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/internal/testutils"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDialTarget = "dummy"
|
||||
defaultRPCTimeout = 5 * time.Second
|
||||
defaultTestTimeout = 1 * time.Second
|
||||
defaultDialTarget = "dummy"
|
||||
defaultRPCTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) {
|
||||
t.Helper()
|
||||
|
||||
server, sCleanup, err := fakeserver.Start()
|
||||
server, sCleanup, err := fakeserver.Start(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start fake RLS server: %v", err)
|
||||
}
|
||||
|
||||
cc, cCleanup, err := server.ClientConn()
|
||||
if err != nil {
|
||||
sCleanup()
|
||||
t.Fatalf("Failed to get a ClientConn to the RLS server: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -59,7 +62,7 @@ func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) {
|
|||
}
|
||||
|
||||
// TestLookupFailure verifies the case where the RLS server returns an error.
|
||||
func TestLookupFailure(t *testing.T) {
|
||||
func (s) TestLookupFailure(t *testing.T) {
|
||||
server, cc, cleanup := setup(t)
|
||||
defer cleanup()
|
||||
|
||||
|
|
@ -68,64 +71,50 @@ func TestLookupFailure(t *testing.T) {
|
|||
|
||||
rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout)
|
||||
|
||||
errCh := make(chan error)
|
||||
errCh := testutils.NewChannel()
|
||||
rlsClient.lookup("", nil, func(targets []string, headerData string, err error) {
|
||||
if err == nil {
|
||||
errCh <- errors.New("rlsClient.lookup() succeeded, should have failed")
|
||||
errCh.Send(errors.New("rlsClient.lookup() succeeded, should have failed"))
|
||||
return
|
||||
}
|
||||
if len(targets) != 0 || headerData != "" {
|
||||
errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData)
|
||||
errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData))
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
errCh.Send(nil)
|
||||
})
|
||||
|
||||
timer := time.NewTimer(defaultTestTimeout)
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("Timeout when expecting a routeLookup callback")
|
||||
case err := <-errCh:
|
||||
timer.Stop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if e, err := errCh.Receive(); err != nil || e != nil {
|
||||
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLookupDeadlineExceeded tests the case where the RPC deadline associated
|
||||
// with the lookup expires.
|
||||
func TestLookupDeadlineExceeded(t *testing.T) {
|
||||
func (s) TestLookupDeadlineExceeded(t *testing.T) {
|
||||
_, cc, cleanup := setup(t)
|
||||
defer cleanup()
|
||||
|
||||
// Give the Lookup RPC a small deadline, but don't setup the fake server to
|
||||
// return anything. So the Lookup call will block and eventuall expire.
|
||||
// return anything. So the Lookup call will block and eventually expire.
|
||||
rlsClient := newRLSClient(cc, defaultDialTarget, 100*time.Millisecond)
|
||||
|
||||
errCh := make(chan error)
|
||||
errCh := testutils.NewChannel()
|
||||
rlsClient.lookup("", nil, func(_ []string, _ string, err error) {
|
||||
if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
|
||||
errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)
|
||||
errCh.Send(fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded))
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
errCh.Send(nil)
|
||||
})
|
||||
|
||||
timer := time.NewTimer(defaultTestTimeout)
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("Timeout when expecting a routeLookup callback")
|
||||
case err := <-errCh:
|
||||
timer.Stop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if e, err := errCh.Receive(); err != nil || e != nil {
|
||||
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLookupSuccess verifies the successful Lookup API case.
|
||||
func TestLookupSuccess(t *testing.T) {
|
||||
func (s) TestLookupSuccess(t *testing.T) {
|
||||
server, cc, cleanup := setup(t)
|
||||
defer cleanup()
|
||||
|
||||
|
|
@ -148,33 +137,29 @@ func TestLookupSuccess(t *testing.T) {
|
|||
|
||||
rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout)
|
||||
|
||||
errCh := make(chan error)
|
||||
errCh := testutils.NewChannel()
|
||||
rlsClient.lookup(rlsReqPath, rlsReqKeyMap, func(targets []string, hd string, err error) {
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("rlsClient.Lookup() failed: %v", err)
|
||||
errCh.Send(fmt.Errorf("rlsClient.Lookup() failed: %v", err))
|
||||
return
|
||||
}
|
||||
if !cmp.Equal(targets, wantRespTargets) || hd != wantHeaderData {
|
||||
errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData)
|
||||
errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData))
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
errCh.Send(nil)
|
||||
})
|
||||
|
||||
// Make sure that the fake server received the expected RouteLookupRequest
|
||||
// proto.
|
||||
timer := time.NewTimer(defaultTestTimeout)
|
||||
select {
|
||||
case gotLookupRequest := <-server.RequestChan:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" {
|
||||
t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff)
|
||||
}
|
||||
case <-timer.C:
|
||||
req, err := server.RequestChan.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("Timed out wile waiting for a RouteLookupRequest")
|
||||
}
|
||||
gotLookupRequest := req.(*rlspb.RouteLookupRequest)
|
||||
if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" {
|
||||
t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// We setup the fake server to return this response when it receives a
|
||||
// request.
|
||||
|
|
@ -185,14 +170,7 @@ func TestLookupSuccess(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
timer = time.NewTimer(defaultTestTimeout)
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Fatal("Timeout when expecting a routeLookup callback")
|
||||
case err := <-errCh:
|
||||
timer.Stop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if e, err := errCh.Receive(); err != nil || e != nil {
|
||||
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -71,6 +71,40 @@ type lbConfig struct {
|
|||
cpConfig map[string]json.RawMessage
|
||||
}
|
||||
|
||||
func (lbCfg *lbConfig) Equal(other *lbConfig) bool {
|
||||
return lbCfg.kbMap.Equal(other.kbMap) &&
|
||||
lbCfg.lookupService == other.lookupService &&
|
||||
lbCfg.lookupServiceTimeout == other.lookupServiceTimeout &&
|
||||
lbCfg.maxAge == other.maxAge &&
|
||||
lbCfg.staleAge == other.staleAge &&
|
||||
lbCfg.cacheSizeBytes == other.cacheSizeBytes &&
|
||||
lbCfg.rpStrategy == other.rpStrategy &&
|
||||
lbCfg.defaultTarget == other.defaultTarget &&
|
||||
lbCfg.cpName == other.cpName &&
|
||||
lbCfg.cpTargetField == other.cpTargetField &&
|
||||
cpConfigEqual(lbCfg.cpConfig, other.cpConfig)
|
||||
}
|
||||
|
||||
func cpConfigEqual(am, bm map[string]json.RawMessage) bool {
|
||||
if (bm == nil) != (am == nil) {
|
||||
return false
|
||||
}
|
||||
if len(bm) != len(am) {
|
||||
return false
|
||||
}
|
||||
|
||||
for k, jsonA := range am {
|
||||
jsonB, ok := bm[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(jsonA, jsonB) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// This struct resembles the JSON respresentation of the loadBalancing config
|
||||
// and makes it easier to unmarshal.
|
||||
type lbConfigJSON struct {
|
||||
|
|
|
|||
|
|
@ -49,20 +49,22 @@ func init() {
|
|||
balancer.Register(&dummyBB{})
|
||||
}
|
||||
|
||||
func (lbCfg *lbConfig) Equal(other *lbConfig) bool {
|
||||
// This only ignores the keyBuilderMap field because its internals are not
|
||||
// exported, and hence not possible to specify in the want section of the
|
||||
// test.
|
||||
return lbCfg.lookupService == other.lookupService &&
|
||||
lbCfg.lookupServiceTimeout == other.lookupServiceTimeout &&
|
||||
lbCfg.maxAge == other.maxAge &&
|
||||
lbCfg.staleAge == other.staleAge &&
|
||||
lbCfg.cacheSizeBytes == other.cacheSizeBytes &&
|
||||
lbCfg.rpStrategy == other.rpStrategy &&
|
||||
lbCfg.defaultTarget == other.defaultTarget &&
|
||||
lbCfg.cpName == other.cpName &&
|
||||
lbCfg.cpTargetField == other.cpTargetField &&
|
||||
cmp.Equal(lbCfg.cpConfig, other.cpConfig)
|
||||
// testEqual reports whether the lbCfgs a and b are equal. This is to be used
|
||||
// only from tests. This ignores the keyBuilderMap field because its internals
|
||||
// are not exported, and hence not possible to specify in the want section of
|
||||
// the test. This is fine because we already have tests to make sure that the
|
||||
// keyBuilder is parsed properly from the service config.
|
||||
func testEqual(a, b *lbConfig) bool {
|
||||
return a.lookupService == b.lookupService &&
|
||||
a.lookupServiceTimeout == b.lookupServiceTimeout &&
|
||||
a.maxAge == b.maxAge &&
|
||||
a.staleAge == b.staleAge &&
|
||||
a.cacheSizeBytes == b.cacheSizeBytes &&
|
||||
a.rpStrategy == b.rpStrategy &&
|
||||
a.defaultTarget == b.defaultTarget &&
|
||||
a.cpName == b.cpName &&
|
||||
a.cpTargetField == b.cpTargetField &&
|
||||
cmp.Equal(a.cpConfig, b.cpConfig)
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
|
|
@ -152,7 +154,7 @@ func TestParseConfig(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
lbCfg, err := builder.ParseConfig(test.input)
|
||||
if err != nil || !cmp.Equal(lbCfg, test.wantCfg) {
|
||||
if err != nil || !testEqual(lbCfg.(*lbConfig), test.wantCfg) {
|
||||
t.Errorf("ParseConfig(%s) = {%+v, %v}, want {%+v, nil}", string(test.input), lbCfg, err, test.wantCfg)
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -28,13 +28,13 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/internal/grpcrand"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/balancer/rls/internal/cache"
|
||||
"google.golang.org/grpc/balancer/rls/internal/keys"
|
||||
rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1"
|
||||
"google.golang.org/grpc/internal/grpcrand"
|
||||
"google.golang.org/grpc/internal/testutils"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
|
|
@ -502,7 +502,7 @@ func TestPick(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
rlsCh := make(chan error, 1)
|
||||
rlsCh := testutils.NewChannel()
|
||||
randID := grpcrand.Intn(math.MaxInt32)
|
||||
// We instantiate a fakeChildPicker which will return a fakeSubConn
|
||||
// with configured id. Either the childPicker or the defaultPicker
|
||||
|
|
@ -525,18 +525,18 @@ func TestPick(t *testing.T) {
|
|||
shouldThrottle: func() bool { return test.throttle },
|
||||
startRLS: func(path string, km keys.KeyMap) {
|
||||
if !test.newRLSRequest {
|
||||
rlsCh <- errors.New("RLS request attempted when none was expected")
|
||||
rlsCh.Send(errors.New("RLS request attempted when none was expected"))
|
||||
return
|
||||
}
|
||||
if path != rpcPath {
|
||||
rlsCh <- fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath)
|
||||
rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath))
|
||||
return
|
||||
}
|
||||
if km.Str != wantKeyMapStr {
|
||||
rlsCh <- fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr)
|
||||
rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr))
|
||||
return
|
||||
}
|
||||
rlsCh <- nil
|
||||
rlsCh.Send(nil)
|
||||
},
|
||||
defaultPick: func(info balancer.PickInfo) (balancer.PickResult, error) {
|
||||
if !test.useDefaultPick {
|
||||
|
|
@ -569,15 +569,8 @@ func TestPick(t *testing.T) {
|
|||
// If the test specified that a new RLS request should be made,
|
||||
// verify it.
|
||||
if test.newRLSRequest {
|
||||
timer := time.NewTimer(defaultTestTimeout)
|
||||
select {
|
||||
case err := <-rlsCh:
|
||||
timer.Stop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
case <-timer.C:
|
||||
t.Fatal("Timeout waiting for RLS request to be sent out")
|
||||
if rlsErr, err := rlsCh.Receive(); err != nil || rlsErr != nil {
|
||||
t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ package fakeserver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
|
@ -29,9 +30,14 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
rlsgrpc "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1"
|
||||
rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1"
|
||||
"google.golang.org/grpc/internal/testutils"
|
||||
)
|
||||
|
||||
const defaultDialTimeout = 5 * time.Second
|
||||
const (
|
||||
defaultDialTimeout = 5 * time.Second
|
||||
defaultRPCTimeout = 5 * time.Second
|
||||
defaultChannelBufferSize = 50
|
||||
)
|
||||
|
||||
// Response wraps the response protobuf (xds/LRS) and error that the Server
|
||||
// should send out to the client through a call to stream.Send()
|
||||
|
|
@ -43,29 +49,31 @@ type Response struct {
|
|||
// Server is a fake implementation of RLS. It exposes channels to send/receive
|
||||
// RLS requests and responses.
|
||||
type Server struct {
|
||||
RequestChan chan *rlspb.RouteLookupRequest
|
||||
RequestChan *testutils.Channel
|
||||
ResponseChan chan Response
|
||||
Address string
|
||||
}
|
||||
|
||||
// Start makes a new Server and gets it to start listening on a local port for
|
||||
// gRPC requests. The returned cancel function should be invoked by the caller
|
||||
// upon completion of the test.
|
||||
func Start() (*Server, func(), error) {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, func() {}, fmt.Errorf("net.Listen() failed: %v", err)
|
||||
// Start makes a new Server which uses the provided net.Listener. If lis is nil,
|
||||
// it creates a new net.Listener on a local port. The returned cancel function
|
||||
// should be invoked by the caller upon completion of the test.
|
||||
func Start(lis net.Listener, opts ...grpc.ServerOption) (*Server, func(), error) {
|
||||
if lis == nil {
|
||||
var err error
|
||||
lis, err = net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, func() {}, fmt.Errorf("net.Listen() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
// Give the channels a buffer size of 1 so that we can setup
|
||||
// expectations for one lookup call, without blocking.
|
||||
RequestChan: make(chan *rlspb.RouteLookupRequest, 1),
|
||||
RequestChan: testutils.NewChannelWithSize(defaultChannelBufferSize),
|
||||
ResponseChan: make(chan Response, 1),
|
||||
Address: lis.Addr().String(),
|
||||
}
|
||||
|
||||
server := grpc.NewServer()
|
||||
server := grpc.NewServer(opts...)
|
||||
rlsgrpc.RegisterRouteLookupServiceServer(server, s)
|
||||
go server.Serve(lis)
|
||||
|
||||
|
|
@ -74,9 +82,17 @@ func Start() (*Server, func(), error) {
|
|||
|
||||
// RouteLookup implements the RouteLookupService.
|
||||
func (s *Server) RouteLookup(ctx context.Context, req *rlspb.RouteLookupRequest) (*rlspb.RouteLookupResponse, error) {
|
||||
s.RequestChan <- req
|
||||
resp := <-s.ResponseChan
|
||||
return resp.Resp, resp.Err
|
||||
s.RequestChan.Send(req)
|
||||
|
||||
// The leakchecker fails if we don't exit out of here in a reasonable time.
|
||||
timer := time.NewTimer(defaultRPCTimeout)
|
||||
select {
|
||||
case <-timer.C:
|
||||
return nil, errors.New("default RPC timeout exceeded")
|
||||
case resp := <-s.ResponseChan:
|
||||
timer.Stop()
|
||||
return resp.Resp, resp.Err
|
||||
}
|
||||
}
|
||||
|
||||
// ClientConn returns a grpc.ClientConn connected to the fakeServer.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrRecvTimeout is an error to indicate that a receive operation on the
|
||||
// channel timed out.
|
||||
var ErrRecvTimeout = errors.New("timed out when waiting for value on channel")
|
||||
|
||||
const (
|
||||
// DefaultChanRecvTimeout is the default timeout for receive operations on the
|
||||
// underlying channel.
|
||||
DefaultChanRecvTimeout = 1 * time.Second
|
||||
// DefaultChanBufferSize is the default buffer size of the underlying channel.
|
||||
DefaultChanBufferSize = 1
|
||||
)
|
||||
|
||||
// Channel wraps a generic channel and provides a timed receive operation.
|
||||
type Channel struct {
|
||||
ch chan interface{}
|
||||
}
|
||||
|
||||
// Send sends value on the underlying channel.
|
||||
func (cwt *Channel) Send(value interface{}) {
|
||||
cwt.ch <- value
|
||||
}
|
||||
|
||||
// Receive returns the value received on the underlying channel, or
|
||||
// ErrRecvTimeout if DefaultChanRecvTimeout amount of time elapses.
|
||||
func (cwt *Channel) Receive() (interface{}, error) {
|
||||
timer := time.NewTimer(DefaultChanRecvTimeout)
|
||||
select {
|
||||
case <-timer.C:
|
||||
return nil, ErrRecvTimeout
|
||||
case got := <-cwt.ch:
|
||||
timer.Stop()
|
||||
return got, nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewChannel returns a new Channel.
|
||||
func NewChannel() *Channel {
|
||||
return NewChannelWithSize(DefaultChanBufferSize)
|
||||
}
|
||||
|
||||
// NewChannelWithSize returns a new Channel with a buffer of bufSize.
|
||||
func NewChannelWithSize(bufSize int) *Channel {
|
||||
return &Channel{ch: make(chan interface{}, bufSize)}
|
||||
}
|
||||
Loading…
Reference in New Issue