mirror of https://github.com/grpc/grpc-go.git
rls: control channel implementation (#5046)
This commit is contained in:
parent 7c8a9321b9
commit 50f82701b5
@@ -19,183 +19,36 @@
package rls

import (
    "sync"

    "google.golang.org/grpc"
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/grpclog"
    "google.golang.org/grpc/internal/grpcsync"
)

var (
    _ balancer.Balancer = (*rlsBalancer)(nil)

    // For overriding in tests.
    newRLSClientFunc = newRLSClient

    logger = grpclog.Component("rls")
)

// rlsBalancer implements the RLS LB policy.
type rlsBalancer struct {
    done *grpcsync.Event
    cc   balancer.ClientConn
    opts balancer.BuildOptions
type rlsBalancer struct{}

    // Mutex protects all the state maintained by the LB policy.
    // TODO(easwars): Once we add the cache, we will also have another lock for
    // the cache alone.
    mu    sync.Mutex
    lbCfg *lbConfig        // Most recently received service config.
    rlsCC *grpc.ClientConn // ClientConn to the RLS server.
    rlsC  *rlsClient       // RLS client wrapper.

    ccUpdateCh chan *balancer.ClientConnState
}

// run is a long running goroutine which handles all the updates that the
// balancer wishes to handle. The appropriate updateHandler will push the update
// on to a channel that this goroutine will select on, so that the update is
// handled asynchronously.
func (lb *rlsBalancer) run() {
    for {
        // TODO(easwars): Handle other updates like subConn state changes, RLS
        // responses from the server etc.
        select {
        case u := <-lb.ccUpdateCh:
            lb.handleClientConnUpdate(u)
        case <-lb.done.Done():
            return
        }
    }
}

// handleClientConnUpdate handles updates to the service config.
// If the RLS server name or the RLS RPC timeout changes, it updates the control
// channel accordingly.
// TODO(easwars): Handle updates to other fields in the service config.
func (lb *rlsBalancer) handleClientConnUpdate(ccs *balancer.ClientConnState) {
    logger.Infof("rls: service config: %+v", ccs.BalancerConfig)
    lb.mu.Lock()
    defer lb.mu.Unlock()

    if lb.done.HasFired() {
        logger.Warning("rls: received service config after balancer close")
        return
    }

    newCfg := ccs.BalancerConfig.(*lbConfig)
    if lb.lbCfg.Equal(newCfg) {
        logger.Info("rls: new service config matches existing config")
        return
    }

    lb.updateControlChannel(newCfg)
    lb.lbCfg = newCfg
}

// UpdateClientConnState pushes the received ClientConnState update on the
// update channel which will be processed asynchronously by the run goroutine.
// Implements balancer.Balancer interface.
func (lb *rlsBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
    select {
    case lb.ccUpdateCh <- &ccs:
    case <-lb.done.Done():
    }
    logger.Fatal("rls: UpdateClientConnState is not yet implemented")
    return nil
}

// ResolverError implements balancer.Balancer interface.
func (lb *rlsBalancer) ResolverError(error) {
    // ResolverError is called by gRPC when the name resolver reports an error.
    // TODO(easwars): How do we handle this?
    logger.Fatal("rls: ResolverError is not yet implemented")
}

// UpdateSubConnState implements balancer.Balancer interface.
func (lb *rlsBalancer) UpdateSubConnState(_ balancer.SubConn, _ balancer.SubConnState) {
    logger.Fatal("rls: UpdateSubConnState is not yet implemented")
}

// Cleans up the resources allocated by the LB policy, including the clientConn
// to the RLS server.
// Implements balancer.Balancer.
func (lb *rlsBalancer) Close() {
    lb.mu.Lock()
    defer lb.mu.Unlock()

    lb.done.Fire()
    if lb.rlsCC != nil {
        lb.rlsCC.Close()
    }
    logger.Fatal("rls: Close is not yet implemented")
}

func (lb *rlsBalancer) ExitIdle() {
    // TODO: are we 100% sure this should be a nop?
}

// updateControlChannel updates the RLS client if required.
// Caller must hold lb.mu.
func (lb *rlsBalancer) updateControlChannel(newCfg *lbConfig) {
    oldCfg := lb.lbCfg
    if newCfg.lookupService == oldCfg.lookupService && newCfg.lookupServiceTimeout == oldCfg.lookupServiceTimeout {
        return
    }

    // Use the RPC timeout from the new config, if different from the existing one.
    timeout := oldCfg.lookupServiceTimeout
    if timeout != newCfg.lookupServiceTimeout {
        timeout = newCfg.lookupServiceTimeout
    }

    if newCfg.lookupService == oldCfg.lookupService {
        // This is the case where only the timeout has changed. We will continue
        // to use the existing clientConn, but will create a new rlsClient with
        // the new timeout.
        lb.rlsC = newRLSClientFunc(lb.rlsCC, lb.opts.Target.Endpoint, timeout)
        return
    }

    // This is the case where the RLS server name has changed. We need to create
    // a new clientConn and close the old one.
    var dopts []grpc.DialOption
    if dialer := lb.opts.Dialer; dialer != nil {
        dopts = append(dopts, grpc.WithContextDialer(dialer))
    }
    dopts = append(dopts, dialCreds(lb.opts))

    cc, err := grpc.Dial(newCfg.lookupService, dopts...)
    if err != nil {
        logger.Errorf("rls: dialRLS(%s, %v): %v", newCfg.lookupService, lb.opts, err)
        // An error from a non-blocking dial indicates something serious. We
        // should continue to use the old control channel if one exists, and
        // return so that the rest of the config updates can be processed.
        return
    }
    if lb.rlsCC != nil {
        lb.rlsCC.Close()
    }
    lb.rlsCC = cc
    lb.rlsC = newRLSClientFunc(cc, lb.opts.Target.Endpoint, timeout)
}

func dialCreds(opts balancer.BuildOptions) grpc.DialOption {
    // The control channel should use the same authority as that of the parent
    // channel. This ensures that the identity of the RLS server and that of the
    // backend is the same, so if the RLS config is injected by an attacker, it
    // cannot cause leakage of private information contained in headers set by
    // the application.
    server := opts.Target.Authority
    switch {
    case opts.DialCreds != nil:
        if err := opts.DialCreds.OverrideServerName(server); err != nil {
            logger.Warningf("rls: OverrideServerName(%s) = (%v), using Insecure", server, err)
            return grpc.WithInsecure()
        }
        return grpc.WithTransportCredentials(opts.DialCreds)
    case opts.CredsBundle != nil:
        return grpc.WithTransportCredentials(opts.CredsBundle.TransportCredentials())
    default:
        logger.Warning("rls: no credentials available, using Insecure")
        return grpc.WithInsecure()
    }
    logger.Fatal("rls: ExitIdle is not yet implemented")
}

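The removed balancer above funnels every ClientConnState update through ccUpdateCh so that a single run goroutine handles updates in order, without per-handler locking. Below is a minimal, self-contained sketch of that serialization pattern; all names in it (update, policy, the example config string) are illustrative and not part of the diff.

package main

import (
    "fmt"
    "time"
)

// update stands in for balancer.ClientConnState.
type update struct{ cfg string }

type policy struct {
    updateCh chan *update
    done     chan struct{}
}

// run drains updateCh on a single goroutine, so handlers never race with each
// other; done unblocks the select during shutdown.
func (p *policy) run() {
    for {
        select {
        case u := <-p.updateCh:
            fmt.Println("handling config:", u.cfg)
        case <-p.done:
            return
        }
    }
}

// UpdateClientConnState hands the update off to the run goroutine and bails
// out if the policy has already been closed.
func (p *policy) UpdateClientConnState(u *update) {
    select {
    case p.updateCh <- u:
    case <-p.done:
    }
}

func main() {
    p := &policy{updateCh: make(chan *update), done: make(chan struct{})}
    go p.run()
    p.UpdateClientConnState(&update{cfg: "lookupService: rls.example.com"})
    time.Sleep(10 * time.Millisecond) // crude settle for the example; the real code uses a grpcsync.Event
    close(p.done)
}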
@@ -1,238 +0,0 @@
/*
 *
 * Copyright 2020 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rls

import (
    "context"
    "net"
    "testing"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver"
    "google.golang.org/grpc/credentials"
    "google.golang.org/grpc/internal/grpctest"
    "google.golang.org/grpc/internal/testutils"
    "google.golang.org/grpc/testdata"
)

const defaultTestTimeout = 1 * time.Second

type s struct {
    grpctest.Tester
}

func Test(t *testing.T) {
    grpctest.RunSubTests(t, s{})
}

type listenerWrapper struct {
    net.Listener
    connCh *testutils.Channel
}

// Accept waits for and returns the next connection to the listener.
func (l *listenerWrapper) Accept() (net.Conn, error) {
    c, err := l.Listener.Accept()
    if err != nil {
        return nil, err
    }
    l.connCh.Send(c)
    return c, nil
}

func setupwithListener(t *testing.T, opts ...grpc.ServerOption) (*fakeserver.Server, *listenerWrapper, func()) {
    t.Helper()

    l, err := net.Listen("tcp", "localhost:0")
    if err != nil {
        t.Fatalf("net.Listen(tcp, localhost:0): %v", err)
    }
    lw := &listenerWrapper{
        Listener: l,
        connCh:   testutils.NewChannel(),
    }

    server, cleanup, err := fakeserver.Start(lw, opts...)
    if err != nil {
        t.Fatalf("fakeserver.Start(): %v", err)
    }
    t.Logf("Fake RLS server started at %s ...", server.Address)

    return server, lw, cleanup
}

type testBalancerCC struct {
    balancer.ClientConn
}

// TestUpdateControlChannelFirstConfig tests the scenario where the LB policy
// receives its first service config and verifies that a control channel to the
// RLS server specified in the serviceConfig is established.
func (s) TestUpdateControlChannelFirstConfig(t *testing.T) {
    server, lis, cleanup := setupwithListener(t)
    defer cleanup()

    bb := balancer.Get(rlsBalancerName)
    if bb == nil {
        t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
    }
    rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{})
    defer rlsB.Close()
    t.Log("Built RLS LB policy ...")

    lbCfg := &lbConfig{lookupService: server.Address}
    t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
    rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})

    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    defer cancel()
    if _, err := lis.connCh.Receive(ctx); err != nil {
        t.Fatal("Timeout expired when waiting for LB policy to create control channel")
    }

    // TODO: Verify channel connectivity state once control channel connectivity
    // state monitoring is in place.

    // TODO: Verify RLS RPC can be made once we integrate with the picker.
}

// TestUpdateControlChannelSwitch tests the scenario where a control channel
// exists and the LB policy receives a new serviceConfig with a different RLS
// server name. Verifies that the new control channel is created and the old one
// is closed (the leakchecker takes care of this).
func (s) TestUpdateControlChannelSwitch(t *testing.T) {
    server1, lis1, cleanup1 := setupwithListener(t)
    defer cleanup1()

    server2, lis2, cleanup2 := setupwithListener(t)
    defer cleanup2()

    bb := balancer.Get(rlsBalancerName)
    if bb == nil {
        t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
    }
    rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{})
    defer rlsB.Close()
    t.Log("Built RLS LB policy ...")

    lbCfg := &lbConfig{lookupService: server1.Address}
    t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
    rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})

    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    defer cancel()
    if _, err := lis1.connCh.Receive(ctx); err != nil {
        t.Fatal("Timeout expired when waiting for LB policy to create control channel")
    }

    lbCfg = &lbConfig{lookupService: server2.Address}
    t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
    rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})

    if _, err := lis2.connCh.Receive(ctx); err != nil {
        t.Fatal("Timeout expired when waiting for LB policy to create control channel")
    }

    // TODO: Verify channel connectivity state once control channel connectivity
    // state monitoring is in place.

    // TODO: Verify RLS RPC can be made once we integrate with the picker.
}

// TestUpdateControlChannelTimeout tests the scenario where the LB policy
// receives a service config update with a different lookupServiceTimeout, but
// the lookupService itself remains unchanged. It verifies that the LB policy
// does not create a new control channel in this case.
func (s) TestUpdateControlChannelTimeout(t *testing.T) {
    server, lis, cleanup := setupwithListener(t)
    defer cleanup()

    bb := balancer.Get(rlsBalancerName)
    if bb == nil {
        t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
    }
    rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{})
    defer rlsB.Close()
    t.Log("Built RLS LB policy ...")

    lbCfg := &lbConfig{lookupService: server.Address, lookupServiceTimeout: 1 * time.Second}
    t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
    rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})

    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    defer cancel()
    if _, err := lis.connCh.Receive(ctx); err != nil {
        t.Fatal("Timeout expired when waiting for LB policy to create control channel")
    }

    lbCfg = &lbConfig{lookupService: server.Address, lookupServiceTimeout: 2 * time.Second}
    t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
    rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
    if _, err := lis.connCh.Receive(ctx); err != context.DeadlineExceeded {
        t.Fatal("LB policy created new control channel when only lookupServiceTimeout changed")
    }

    // TODO: Verify channel connectivity state once control channel connectivity
    // state monitoring is in place.

    // TODO: Verify RLS RPC can be made once we integrate with the picker.
}

// TestUpdateControlChannelWithCreds tests the scenario where the control
// channel is to be established with credentials from the parent channel.
func (s) TestUpdateControlChannelWithCreds(t *testing.T) {
    sCreds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
    if err != nil {
        t.Fatalf("credentials.NewServerTLSFromFile(server1.pem, server1.key) = %v", err)
    }
    cCreds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "")
    if err != nil {
        t.Fatalf("credentials.NewClientTLSFromFile(ca.pem) = %v", err)
    }

    server, lis, cleanup := setupwithListener(t, grpc.Creds(sCreds))
    defer cleanup()

    bb := balancer.Get(rlsBalancerName)
    if bb == nil {
        t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
    }
    rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{
        DialCreds: cCreds,
    })
    defer rlsB.Close()
    t.Log("Built RLS LB policy ...")

    lbCfg := &lbConfig{lookupService: server.Address}
    t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
    rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})

    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    defer cancel()
    if _, err := lis.connCh.Receive(ctx); err != nil {
        t.Fatal("Timeout expired when waiting for LB policy to create control channel")
    }

    // TODO: Verify channel connectivity state once control channel connectivity
    // state monitoring is in place.

    // TODO: Verify RLS RPC can be made once we integrate with the picker.
}

@@ -21,10 +21,9 @@ package rls

import (
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/internal/grpcsync"
)

const rlsBalancerName = "rls"
const rlsBalancerName = "rls_experimental"

func init() {
    balancer.Register(&rlsBB{})

@@ -41,13 +40,6 @@ func (*rlsBB) Name() string {
}

func (*rlsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
    lb := &rlsBalancer{
        done:       grpcsync.NewEvent(),
        cc:         cc,
        opts:       opts,
        lbCfg:      &lbConfig{},
        ccUpdateCh: make(chan *balancer.ClientConnState),
    }
    go lb.run()
    return lb
    // TODO(easwars): Fix this once the LB policy implementation is pulled in.
    return &rlsBalancer{}
}

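With the registered name changed to rls_experimental, a client channel would opt into this policy through its service config. The snippet below is only a hedged illustration of the selection mechanism; the empty config object elides the policy's actual fields (routeLookupConfig and friends), which this particular JSON would not parse without.

package main

import "fmt"

func main() {
    // Hypothetical service config selecting the renamed policy by its new name.
    const serviceConfig = `{
  "loadBalancingConfig": [
    { "rls_experimental": {} }
  ]
}`
    fmt.Println(serviceConfig)
    // A channel could supply this via:
    //   grpc.Dial(target, grpc.WithDefaultServiceConfig(serviceConfig))
}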
@@ -1,80 +0,0 @@
/*
 *
 * Copyright 2020 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rls

import (
    "context"
    "time"

    "google.golang.org/grpc"
    rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
    rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
)

// For gRPC services using RLS, the value of target_type in the
// RouteLookupServiceRequest will be set to this.
const grpcTargetType = "grpc"

// rlsClient is a simple wrapper around a RouteLookupService client which
// provides non-blocking semantics on top of a blocking unary RPC call.
//
// The RLS LB policy creates a new rlsClient object with the following values:
// * a grpc.ClientConn to the RLS server using appropriate credentials from the
//   parent channel
// * dialTarget corresponding to the original user dial target, e.g.
//   "firestore.googleapis.com".
//
// The RLS LB policy uses an adaptive throttler to perform client side
// throttling and asks this client to make an RPC call only after checking with
// the throttler.
type rlsClient struct {
    stub rlsgrpc.RouteLookupServiceClient
    // origDialTarget is the original dial target of the user and sent in each
    // RouteLookup RPC made to the RLS server.
    origDialTarget string
    // rpcTimeout specifies the timeout for the RouteLookup RPC call. The LB
    // policy receives this value in its service config.
    rpcTimeout time.Duration
}

func newRLSClient(cc *grpc.ClientConn, dialTarget string, rpcTimeout time.Duration) *rlsClient {
    return &rlsClient{
        stub:           rlsgrpc.NewRouteLookupServiceClient(cc),
        origDialTarget: dialTarget,
        rpcTimeout:     rpcTimeout,
    }
}

type lookupCallback func(targets []string, headerData string, err error)

// lookup starts a RouteLookup RPC in a separate goroutine and returns the
// results (and error, if any) in the provided callback.
func (c *rlsClient) lookup(keyMap map[string]string, cb lookupCallback) {
    go func() {
        ctx, cancel := context.WithTimeout(context.Background(), c.rpcTimeout)
        resp, err := c.stub.RouteLookup(ctx, &rlspb.RouteLookupRequest{
            // TODO(easwars): Use extra_keys field to populate host, service and
            // method keys.
            TargetType: grpcTargetType,
            KeyMap:     keyMap,
        })
        cb(resp.GetTargets(), resp.GetHeaderData(), err)
        cancel()
    }()
}

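For reference, this is roughly how a caller might have driven the deleted client's non-blocking lookup. The signatures come from the file above; the wrapper function, dial target, and key map are hypothetical.

package rls

import (
    "log"
    "time"

    "google.golang.org/grpc"
)

// exampleLookup is a hypothetical caller of the deleted rlsClient. lookup
// returns immediately; the callback fires on another goroutine once the
// RouteLookup RPC completes or its rpcTimeout expires.
func exampleLookup(cc *grpc.ClientConn) {
    done := make(chan struct{})
    c := newRLSClient(cc, "firestore.googleapis.com", 2*time.Second)
    c.lookup(map[string]string{"k1": "v1"}, func(targets []string, headerData string, err error) {
        if err != nil {
            log.Printf("lookup failed: %v", err)
        } else {
            log.Printf("targets = %v, headerData = %q", targets, headerData)
        }
        close(done)
    })
    <-done // wait only so the example has a visible result
}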
@@ -1,178 +0,0 @@
/*
 *
 * Copyright 2020 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rls

import (
    "context"
    "errors"
    "fmt"
    "testing"
    "time"

    "github.com/golang/protobuf/proto"
    "github.com/google/go-cmp/cmp"
    "google.golang.org/grpc"
    "google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver"
    "google.golang.org/grpc/codes"
    rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
    "google.golang.org/grpc/internal/testutils"
    "google.golang.org/grpc/status"
)

const (
    defaultDialTarget = "dummy"
    defaultRPCTimeout = 5 * time.Second
)

func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) {
    t.Helper()

    server, sCleanup, err := fakeserver.Start(nil)
    if err != nil {
        t.Fatalf("Failed to start fake RLS server: %v", err)
    }

    cc, cCleanup, err := server.ClientConn()
    if err != nil {
        sCleanup()
        t.Fatalf("Failed to get a ClientConn to the RLS server: %v", err)
    }

    return server, cc, func() {
        sCleanup()
        cCleanup()
    }
}

// TestLookupFailure verifies the case where the RLS server returns an error.
func (s) TestLookupFailure(t *testing.T) {
    server, cc, cleanup := setup(t)
    defer cleanup()

    // We set up the fake server to return an error.
    server.ResponseChan <- fakeserver.Response{Err: errors.New("rls failure")}

    rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout)

    errCh := testutils.NewChannel()
    rlsClient.lookup(nil, func(targets []string, headerData string, err error) {
        if err == nil {
            errCh.Send(errors.New("rlsClient.lookup() succeeded, should have failed"))
            return
        }
        if len(targets) != 0 || headerData != "" {
            errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData))
            return
        }
        errCh.Send(nil)
    })

    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    defer cancel()
    if e, err := errCh.Receive(ctx); err != nil || e != nil {
        t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
    }
}

// TestLookupDeadlineExceeded tests the case where the RPC deadline associated
// with the lookup expires.
func (s) TestLookupDeadlineExceeded(t *testing.T) {
    _, cc, cleanup := setup(t)
    defer cleanup()

    // Give the Lookup RPC a small deadline, but don't set up the fake server
    // to return anything. So the Lookup call will block and eventually expire.
    rlsClient := newRLSClient(cc, defaultDialTarget, 100*time.Millisecond)

    errCh := testutils.NewChannel()
    rlsClient.lookup(nil, func(_ []string, _ string, err error) {
        if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
            errCh.Send(fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded))
            return
        }
        errCh.Send(nil)
    })

    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    defer cancel()
    if e, err := errCh.Receive(ctx); err != nil || e != nil {
        t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
    }
}

// TestLookupSuccess verifies the successful Lookup API case.
func (s) TestLookupSuccess(t *testing.T) {
    server, cc, cleanup := setup(t)
    defer cleanup()

    const wantHeaderData = "headerData"

    rlsReqKeyMap := map[string]string{
        "k1": "v1",
        "k2": "v2",
    }
    wantLookupRequest := &rlspb.RouteLookupRequest{
        // TODO(easwars): Use extra_keys field to populate host, service and
        // method keys.
        TargetType: "grpc",
        KeyMap:     rlsReqKeyMap,
    }
    wantRespTargets := []string{"us_east_1.firestore.googleapis.com"}

    rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout)

    errCh := testutils.NewChannel()
    rlsClient.lookup(rlsReqKeyMap, func(targets []string, hd string, err error) {
        if err != nil {
            errCh.Send(fmt.Errorf("rlsClient.Lookup() failed: %v", err))
            return
        }
        if !cmp.Equal(targets, wantRespTargets) || hd != wantHeaderData {
            errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData))
            return
        }
        errCh.Send(nil)
    })

    // Make sure that the fake server received the expected RouteLookupRequest
    // proto.
    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    defer cancel()
    req, err := server.RequestChan.Receive(ctx)
    if err != nil {
        t.Fatalf("Timed out while waiting for a RouteLookupRequest")
    }
    gotLookupRequest := req.(*rlspb.RouteLookupRequest)
    if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" {
        t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff)
    }

    // We set up the fake server to return this response when it receives a
    // request.
    server.ResponseChan <- fakeserver.Response{
        Resp: &rlspb.RouteLookupResponse{
            Targets:    wantRespTargets,
            HeaderData: wantHeaderData,
        },
    }

    if e, err := errCh.Receive(ctx); err != nil || e != nil {
        t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
    }
}

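The deleted tests above repeat one pattern worth calling out: the lookup callback runs on another goroutine, so each test funnels its verdict through a testutils.Channel and receives it under a context deadline, which turns a hung callback into a test failure rather than a hang. A generic sketch of the pattern follows; exampleAsyncAssert and asyncOp are hypothetical stand-ins, not part of the diff.

package rls

import (
    "context"
    "errors"
    "testing"
    "time"

    "google.golang.org/grpc/internal/testutils"
)

// exampleAsyncAssert sketches the verdict-over-channel pattern; asyncOp stands
// in for any callback-based API under test, such as rlsClient.lookup.
func exampleAsyncAssert(t *testing.T, asyncOp func(cb func(err error))) {
    t.Helper()

    errCh := testutils.NewChannel()
    asyncOp(func(err error) {
        if err != nil {
            errCh.Send(errors.New("operation failed"))
            return
        }
        errCh.Send(nil) // success verdict
    })

    // Receive blocks until the callback reports or the deadline fires.
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    if e, err := errCh.Receive(ctx); err != nil || e != nil {
        t.Fatalf("async verdict: %v, error receiving from channel: %v", e, err)
    }
}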
@@ -61,7 +61,7 @@ func testEqual(a, b *lbConfig) bool {
        childPolicyConfigEqual(a.childPolicyConfig, b.childPolicyConfig)
}

func TestParseConfig(t *testing.T) {
func (s) TestParseConfig(t *testing.T) {
    childPolicyTargetFieldVal, _ := json.Marshal(dummyChildPolicyTarget)
    tests := []struct {
        desc string

@@ -158,7 +158,7 @@ func TestParseConfig(t *testing.T) {
    }
}

func TestParseConfigErrors(t *testing.T) {
func (s) TestParseConfigErrors(t *testing.T) {
    tests := []struct {
        desc  string
        input []byte

@@ -0,0 +1,206 @@
/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rls

import (
    "context"
    "fmt"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/balancer/rls/internal/adaptive"
    "google.golang.org/grpc/connectivity"
    "google.golang.org/grpc/internal"
    internalgrpclog "google.golang.org/grpc/internal/grpclog"
    "google.golang.org/grpc/internal/pretty"
    rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
    rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
)

var newAdaptiveThrottler = func() adaptiveThrottler { return adaptive.New() }

type adaptiveThrottler interface {
    ShouldThrottle() bool
    RegisterBackendResponse(throttled bool)
}

// controlChannel is a wrapper around the gRPC channel to the RLS server
// specified in the service config.
type controlChannel struct {
    // rpcTimeout specifies the timeout for the RouteLookup RPC call. The LB
    // policy receives this value in its service config.
    rpcTimeout time.Duration
    // backToReadyCh is the channel on which an update is pushed when the
    // connectivity state changes from READY --> TRANSIENT_FAILURE --> READY.
    backToReadyCh chan struct{}
    // throttler is an adaptive throttling implementation used to avoid
    // hammering the RLS service while it is overloaded or down.
    throttler adaptiveThrottler

    cc     *grpc.ClientConn
    client rlsgrpc.RouteLookupServiceClient
    logger *internalgrpclog.PrefixLogger
}

func newControlChannel(rlsServerName string, rpcTimeout time.Duration, bOpts balancer.BuildOptions, backToReadyCh chan struct{}) (*controlChannel, error) {
    ctrlCh := &controlChannel{
        rpcTimeout:    rpcTimeout,
        backToReadyCh: backToReadyCh,
        throttler:     newAdaptiveThrottler(),
    }
    ctrlCh.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-control-channel %p] ", ctrlCh))

    dopts, err := ctrlCh.dialOpts(bOpts)
    if err != nil {
        return nil, err
    }
    ctrlCh.cc, err = grpc.Dial(rlsServerName, dopts...)
    if err != nil {
        return nil, err
    }
    ctrlCh.client = rlsgrpc.NewRouteLookupServiceClient(ctrlCh.cc)
    ctrlCh.logger.Infof("Control channel created to RLS server at: %v", rlsServerName)

    go ctrlCh.monitorConnectivityState()
    return ctrlCh, nil
}

// dialOpts constructs the dial options for the control plane channel.
func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions) ([]grpc.DialOption, error) {
    // The control plane channel will use the same authority as the parent
    // channel for server authorization. This ensures that the identity of the
    // RLS server and the identity of the backends is the same, so if the RLS
    // config is injected by an attacker, it cannot cause leakage of private
    // information contained in headers set by the application.
    dopts := []grpc.DialOption{grpc.WithAuthority(bOpts.Authority)}
    if bOpts.Dialer != nil {
        dopts = append(dopts, grpc.WithContextDialer(bOpts.Dialer))
    }

    // The control channel will use the channel credentials from the parent
    // channel, including any call creds associated with the channel creds.
    var credsOpt grpc.DialOption
    switch {
    case bOpts.DialCreds != nil:
        credsOpt = grpc.WithTransportCredentials(bOpts.DialCreds.Clone())
    case bOpts.CredsBundle != nil:
        // The "fallback" mode in google default credentials (which is the only
        // type of credentials we expect to be used with RLS) uses TLS/ALTS
        // creds for transport and uses the same call creds as that on the
        // parent bundle.
        bundle, err := bOpts.CredsBundle.NewWithMode(internal.CredsBundleModeFallback)
        if err != nil {
            return nil, err
        }
        credsOpt = grpc.WithCredentialsBundle(bundle)
    default:
        cc.logger.Warningf("no credentials available, using Insecure")
        credsOpt = grpc.WithInsecure()
    }
    return append(dopts, credsOpt), nil
}

func (cc *controlChannel) monitorConnectivityState() {
    cc.logger.Infof("Starting connectivity state monitoring goroutine")
    // Since we use two mechanisms to deal with the RLS server being down:
    //   - adaptive throttling for the channel as a whole
    //   - exponential backoff on a per-request basis
    // we need a way to avoid double-penalizing requests by counting failures
    // toward both mechanisms when the RLS server is unreachable.
    //
    // To accomplish this, we monitor the state of the control plane channel. If
    // the state has been TRANSIENT_FAILURE since the last time it was in state
    // READY, and it then transitions into state READY, we push on a channel
    // which is being read by the LB policy.
    //
    // The LB policy will iterate through the cache to reset the backoff
    // timeouts in all cache entries. Specifically, this means that it will
    // reset the backoff state and cancel the pending backoff timer. Note that
    // when cancelling the backoff timer, just like when the backoff timer fires
    // normally, a new picker is returned to the channel, to force it to
    // re-process any wait-for-ready RPCs that may still be queued if we failed
    // them while we were in backoff. However, we should optimize this case by
    // returning only one new picker, regardless of how many backoff timers are
    // cancelled.

    // Using the background context is fine here since we check for the ClientConn
    // entering SHUTDOWN and return early in that case.
    ctx := context.Background()

    first := true
    for {
        // Wait for the control channel to become READY.
        for s := cc.cc.GetState(); s != connectivity.Ready; s = cc.cc.GetState() {
            if s == connectivity.Shutdown {
                return
            }
            cc.cc.WaitForStateChange(ctx, s)
        }
        cc.logger.Infof("Connectivity state is READY")

        if !first {
            cc.logger.Infof("Control channel back to READY")
            cc.backToReadyCh <- struct{}{}
        }
        first = false

        // Wait for the control channel to move out of READY.
        cc.cc.WaitForStateChange(ctx, connectivity.Ready)
        if cc.cc.GetState() == connectivity.Shutdown {
            return
        }
        cc.logger.Infof("Connectivity state is %s", cc.cc.GetState())
    }
}

func (cc *controlChannel) close() {
    cc.logger.Infof("Closing control channel")
    cc.cc.Close()
}

type lookupCallback func(targets []string, headerData string, err error)

// lookup starts a RouteLookup RPC in a separate goroutine and returns the
// results (and error, if any) in the provided callback.
//
// The returned boolean indicates whether the request was throttled by the
// client-side adaptive throttling algorithm, in which case the provided
// callback will not be invoked.
func (cc *controlChannel) lookup(reqKeys map[string]string, reason rlspb.RouteLookupRequest_Reason, staleHeaders string, cb lookupCallback) (throttled bool) {
    if cc.throttler.ShouldThrottle() {
        cc.logger.Infof("RLS request throttled by client-side adaptive throttling")
        return true
    }
    go func() {
        req := &rlspb.RouteLookupRequest{
            TargetType:      "grpc",
            KeyMap:          reqKeys,
            Reason:          reason,
            StaleHeaderData: staleHeaders,
        }
        cc.logger.Infof("Sending RLS request %+v", pretty.ToJSON(req))

        ctx, cancel := context.WithTimeout(context.Background(), cc.rpcTimeout)
        defer cancel()
        resp, err := cc.client.RouteLookup(ctx, req)
        cb(resp.GetTargets(), resp.GetHeaderData(), err)
    }()
    return false
}

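Taken together, an LB policy would drive the new control channel roughly as follows. This is a hedged sketch built only from the signatures in the file above; the wrapper function, server name, key map, and logging are illustrative, and the real policy reacts to backToReadyCh by resetting cache backoff state rather than just logging.

package rls

import (
    "log"
    "time"

    "google.golang.org/grpc/balancer"
    rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
)

func exampleControlChannelUse(bOpts balancer.BuildOptions) {
    backToReady := make(chan struct{})
    ctrlCh, err := newControlChannel("rls.example.com:443", 5*time.Second, bOpts, backToReady)
    if err != nil {
        log.Printf("newControlChannel: %v", err)
        return
    }
    defer ctrlCh.close()

    // The monitoring goroutine pushes here on TRANSIENT_FAILURE -> READY; the
    // real policy would reset backoff state in all cache entries at this point.
    go func() {
        for range backToReady {
            log.Println("control channel back to READY")
        }
    }()

    throttled := ctrlCh.lookup(map[string]string{"k1": "v1"}, rlspb.RouteLookupRequest_REASON_MISS, "" /* staleHeaders */, func(targets []string, headerData string, err error) {
        log.Printf("targets = %v, headerData = %q, err = %v", targets, headerData, err)
    })
    if throttled {
        // The callback will not run; the data-path RPC must fail or fall back.
        log.Println("lookup throttled client-side")
    }
}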
@@ -0,0 +1,469 @@
/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rls

import (
    "context"
    "crypto/tls"
    "crypto/x509"
    "errors"
    "fmt"
    "io/ioutil"
    "strings"
    "testing"
    "time"

    "github.com/google/go-cmp/cmp"
    "google.golang.org/grpc"
    "google.golang.org/grpc/balancer"
    "google.golang.org/grpc/balancer/rls/internal/test/e2e"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/credentials"
    "google.golang.org/grpc/internal"
    rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
    "google.golang.org/grpc/metadata"
    "google.golang.org/grpc/status"
    "google.golang.org/grpc/testdata"
    "google.golang.org/protobuf/proto"
)

// TestControlChannelThrottled tests the case where the adaptive throttler
// indicates that the control channel needs to be throttled.
func (s) TestControlChannelThrottled(t *testing.T) {
    // Start an RLS server and set the throttler to always throttle requests.
    rlsServer, rlsReqCh := setupFakeRLSServer(t, nil)
    overrideAdaptiveThrottler(t, alwaysThrottlingThrottler())

    // Create a control channel to the fake RLS server.
    ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{}, nil)
    if err != nil {
        t.Fatalf("Failed to create control channel to RLS server: %v", err)
    }
    defer ctrlCh.close()

    // Perform the lookup and expect the attempt to be throttled.
    ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, nil)

    select {
    case <-rlsReqCh:
        t.Fatal("RouteLookup RPC invoked when control channel is throttled")
    case <-time.After(defaultTestShortTimeout):
    }
}

// TestLookupFailure tests the case where the RLS server responds with an error.
func (s) TestLookupFailure(t *testing.T) {
    // Start an RLS server and set the throttler to never throttle requests.
    rlsServer, _ := setupFakeRLSServer(t, nil)
    overrideAdaptiveThrottler(t, neverThrottlingThrottler())

    // Set up the RLS server to respond with errors.
    rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse {
        return &e2e.RouteLookupResponse{Err: errors.New("rls failure")}
    })

    // Create a control channel to the fake RLS server.
    ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{}, nil)
    if err != nil {
        t.Fatalf("Failed to create control channel to RLS server: %v", err)
    }
    defer ctrlCh.close()

    // Perform the lookup and expect the callback to be invoked with an error.
    errCh := make(chan error, 1)
    ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
        if err == nil {
            errCh <- errors.New("rlsClient.lookup() succeeded, should have failed")
            return
        }
        errCh <- nil
    })

    select {
    case <-time.After(defaultTestTimeout):
        t.Fatal("timeout when waiting for lookup callback to be invoked")
    case err := <-errCh:
        if err != nil {
            t.Fatal(err)
        }
    }
}

// TestLookupDeadlineExceeded tests the case where the RLS server does not
// respond within the configured rpc timeout.
func (s) TestLookupDeadlineExceeded(t *testing.T) {
    // A unary interceptor which sleeps for long enough to cause lookup RPCs to
    // exceed their deadline.
    rlsReqCh := make(chan struct{}, 1)
    interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
        rlsReqCh <- struct{}{}
        time.Sleep(2 * defaultTestShortTimeout)
        return handler(ctx, req)
    }

    // Start an RLS server and set the throttler to never throttle.
    rlsServer, _ := setupFakeRLSServer(t, nil, grpc.UnaryInterceptor(interceptor))
    overrideAdaptiveThrottler(t, neverThrottlingThrottler())

    // Create a control channel with a small deadline.
    ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestShortTimeout, balancer.BuildOptions{}, nil)
    if err != nil {
        t.Fatalf("Failed to create control channel to RLS server: %v", err)
    }
    defer ctrlCh.close()

    // Perform the lookup and expect the callback to be invoked with an error.
    errCh := make(chan error)
    ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
        if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
            errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)
            return
        }
        errCh <- nil
    })

    select {
    case <-time.After(defaultTestTimeout):
        t.Fatal("timeout when waiting for lookup callback to be invoked")
    case err := <-errCh:
        if err != nil {
            t.Fatal(err)
        }
    }
}

// testCredsBundle wraps a test call creds and real transport creds.
type testCredsBundle struct {
    transportCreds credentials.TransportCredentials
    callCreds      credentials.PerRPCCredentials
}

func (f *testCredsBundle) TransportCredentials() credentials.TransportCredentials {
    return f.transportCreds
}

func (f *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
    return f.callCreds
}

func (f *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
    if mode != internal.CredsBundleModeFallback {
        return nil, fmt.Errorf("unsupported mode: %v", mode)
    }
    return &testCredsBundle{
        transportCreds: f.transportCreds,
        callCreds:      f.callCreds,
    }, nil
}

var (
    // Call creds sent by the testPerRPCCredentials on the client, and verified
    // by an interceptor on the server.
    perRPCCredsData = map[string]string{
        "test-key":     "test-value",
        "test-key-bin": string([]byte{1, 2, 3}),
    }
)

type testPerRPCCredentials struct {
    callCreds map[string]string
}

func (f *testPerRPCCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
    return f.callCreds, nil
}

func (f *testPerRPCCredentials) RequireTransportSecurity() bool {
    return true
}

// Unary server interceptor which validates if the RPC contains call
// credentials which match `perRPCCredsData`.
func callCredsValidatingServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
    md, ok := metadata.FromIncomingContext(ctx)
    if !ok {
        return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context")
    }
    for k, want := range perRPCCredsData {
        got, ok := md[k]
        if !ok {
            return ctx, status.Errorf(codes.PermissionDenied, "didn't find call creds key %v in context", k)
        }
        if got[0] != want {
            return ctx, status.Errorf(codes.PermissionDenied, "for key %v, got value %v, want %v", k, got, want)
        }
    }
    return handler(ctx, req)
}

// makeTLSCreds is a test helper which creates TLS-based transport credentials
// from files specified in the arguments.
func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials {
    cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
    if err != nil {
        t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err)
    }
    b, err := ioutil.ReadFile(testdata.Path(rootsPath))
    if err != nil {
        t.Fatalf("ioutil.ReadFile(%q) failed: %v", rootsPath, err)
    }
    roots := x509.NewCertPool()
    if !roots.AppendCertsFromPEM(b) {
        t.Fatal("failed to append certificates")
    }
    return credentials.NewTLS(&tls.Config{
        Certificates: []tls.Certificate{cert},
        RootCAs:      roots,
    })
}

const (
    wantHeaderData  = "headerData"
    staleHeaderData = "staleHeaderData"
)

var (
    keyMap = map[string]string{
        "k1": "v1",
        "k2": "v2",
    }
    wantTargets   = []string{"us_east_1.firestore.googleapis.com"}
    lookupRequest = &rlspb.RouteLookupRequest{
        TargetType:      "grpc",
        KeyMap:          keyMap,
        Reason:          rlspb.RouteLookupRequest_REASON_MISS,
        StaleHeaderData: staleHeaderData,
    }
    lookupResponse = &e2e.RouteLookupResponse{
        Resp: &rlspb.RouteLookupResponse{
            Targets:    wantTargets,
            HeaderData: wantHeaderData,
        },
    }
)

func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions) {
    // Start an RLS server and set the throttler to never throttle requests.
    rlsServer, _ := setupFakeRLSServer(t, nil, sopts...)
    overrideAdaptiveThrottler(t, neverThrottlingThrottler())

    // Set up the RLS server to respond with a valid response.
    rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse {
        return lookupResponse
    })

    // Verify that the request received by the RLS server matches the expected one.
    rlsServer.SetRequestCallback(func(got *rlspb.RouteLookupRequest) {
        if diff := cmp.Diff(lookupRequest, got, cmp.Comparer(proto.Equal)); diff != "" {
            t.Errorf("RouteLookupRequest diff (-want, +got):\n%s", diff)
        }
    })

    // Create a control channel to the fake server.
    ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, bopts, nil)
    if err != nil {
        t.Fatalf("Failed to create control channel to RLS server: %v", err)
    }
    defer ctrlCh.close()

    // Perform the lookup and expect a successful callback invocation.
    ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
    defer cancel()
    errCh := make(chan error, 1)
    ctrlCh.lookup(keyMap, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(targets []string, headerData string, err error) {
        if err != nil {
            errCh <- fmt.Errorf("rlsClient.lookup() failed with err: %v", err)
            return
        }
        if !cmp.Equal(targets, wantTargets) || headerData != wantHeaderData {
            errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, headerData, wantTargets, wantHeaderData)
            return
        }
        errCh <- nil
    })

    select {
    case <-ctx.Done():
        t.Fatal("timeout when waiting for lookup callback to be invoked")
    case err := <-errCh:
        if err != nil {
            t.Fatal(err)
        }
    }
}

// TestControlChannelCredsSuccess tests creation of the control channel with
// different credentials, which are expected to succeed.
func (s) TestControlChannelCredsSuccess(t *testing.T) {
    serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
    clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")

    tests := []struct {
        name  string
        sopts []grpc.ServerOption
        bopts balancer.BuildOptions
    }{
        {
            name:  "insecure",
            sopts: nil,
            bopts: balancer.BuildOptions{},
        },
        {
            name:  "transport creds only",
            sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
            bopts: balancer.BuildOptions{
                DialCreds: clientCreds,
                Authority: "x.test.example.com",
            },
        },
        {
            name: "creds bundle",
            sopts: []grpc.ServerOption{
                grpc.Creds(serverCreds),
                grpc.UnaryInterceptor(callCredsValidatingServerInterceptor),
            },
            bopts: balancer.BuildOptions{
                CredsBundle: &testCredsBundle{
                    transportCreds: clientCreds,
                    callCreds:      &testPerRPCCredentials{callCreds: perRPCCredsData},
                },
                Authority: "x.test.example.com",
            },
        },
    }
    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            testControlChannelCredsSuccess(t, test.sopts, test.bopts)
        })
    }
}

func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErr string) {
    // Start an RLS server and set the throttler to never throttle requests. The
    // creds failures happen before the RPC handler on the server is invoked.
    // So, there is no need to set up the request and responses on the fake
    // server.
    rlsServer, _ := setupFakeRLSServer(t, nil, sopts...)
    overrideAdaptiveThrottler(t, neverThrottlingThrottler())

    // Create the control channel to the fake server.
    ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, bopts, nil)
    if err != nil {
        t.Fatalf("Failed to create control channel to RLS server: %v", err)
    }
    defer ctrlCh.close()

    // Perform the lookup and expect the callback to be invoked with an error.
    errCh := make(chan error)
    ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
        if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !strings.Contains(st.String(), wantErr) {
            errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErr)
            return
        }
        errCh <- nil
    })

    select {
    case <-time.After(defaultTestTimeout):
        t.Fatal("timeout when waiting for lookup callback to be invoked")
    case err := <-errCh:
        if err != nil {
            t.Fatal(err)
        }
    }
}

// TestControlChannelCredsFailure tests creation of the control channel with
// different credentials, which are expected to fail.
func (s) TestControlChannelCredsFailure(t *testing.T) {
    serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
    clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")

    tests := []struct {
        name     string
        sopts    []grpc.ServerOption
        bopts    balancer.BuildOptions
        wantCode codes.Code
        wantErr  string
    }{
        {
            name:  "transport creds authority mismatch",
            sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
            bopts: balancer.BuildOptions{
                DialCreds: clientCreds,
                Authority: "authority-mismatch",
            },
            wantCode: codes.Unavailable,
            wantErr:  "transport: authentication handshake failed: x509: certificate is valid for *.test.example.com, not authority-mismatch",
        },
        {
            name:  "transport creds handshake failure",
            sopts: nil, // server expects insecure connection
            bopts: balancer.BuildOptions{
                DialCreds: clientCreds,
                Authority: "x.test.example.com",
            },
            wantCode: codes.Unavailable,
            wantErr:  "transport: authentication handshake failed: tls: first record does not look like a TLS handshake",
        },
        {
            name: "call creds mismatch",
            sopts: []grpc.ServerOption{
                grpc.Creds(serverCreds),
                grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), // server expects call creds
            },
            bopts: balancer.BuildOptions{
                CredsBundle: &testCredsBundle{
                    transportCreds: clientCreds,
                    callCreds:      &testPerRPCCredentials{}, // sends no call creds
                },
                Authority: "x.test.example.com",
            },
            wantCode: codes.PermissionDenied,
            wantErr:  "didn't find call creds",
        },
    }
    for _, test := range tests {
        t.Run(test.name, func(t *testing.T) {
            testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErr)
        })
    }
}

type unsupportedCredsBundle struct {
    credentials.Bundle
}

func (*unsupportedCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
    return nil, fmt.Errorf("unsupported mode: %v", mode)
}

// TestNewControlChannelUnsupportedCredsBundle tests the case where the control
// channel is configured with a bundle which does not support the mode we use.
func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) {
    rlsServer, _ := setupFakeRLSServer(t, nil)

    // Create the control channel to the fake server.
    ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{CredsBundle: &unsupportedCredsBundle{}}, nil)
    if err == nil {
        ctrlCh.close()
        t.Fatal("newControlChannel succeeded when expected to fail")
    }
}

@ -0,0 +1,327 @@
/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rls

import (
	"context"
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/balancer/rls/internal/test/e2e"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/internal"
	"google.golang.org/grpc/internal/balancergroup"
	"google.golang.org/grpc/internal/grpctest"
	rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
	internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
	"google.golang.org/grpc/internal/stubserver"
	"google.golang.org/grpc/internal/testutils"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/resolver/manual"
	"google.golang.org/grpc/serviceconfig"
	"google.golang.org/grpc/status"
	testgrpc "google.golang.org/grpc/test/grpc_testing"
	testpb "google.golang.org/grpc/test/grpc_testing"
	"google.golang.org/protobuf/types/known/durationpb"
)

// TODO(easwars): Remove this once all RLS code is merged.
//lint:file-ignore U1000 Ignore all unused code, not all code is merged yet.

const (
	defaultTestTimeout      = 5 * time.Second
	defaultTestShortTimeout = 100 * time.Millisecond
)

func init() {
	balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond
}

type s struct {
	grpctest.Tester
}

func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}

// connWrapper wraps a net.Conn and pushes on a channel when closed.
type connWrapper struct {
	net.Conn
	closeCh *testutils.Channel
}

func (cw *connWrapper) Close() error {
	err := cw.Conn.Close()
	cw.closeCh.Replace(nil)
	return err
}

// listenerWrapper wraps a net.Listener and the returned net.Conn.
//
// It pushes on a channel whenever it accepts a new connection.
type listenerWrapper struct {
	net.Listener
	newConnCh *testutils.Channel
}

func (l *listenerWrapper) Accept() (net.Conn, error) {
	c, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	closeCh := testutils.NewChannel()
	conn := &connWrapper{Conn: c, closeCh: closeCh}
	l.newConnCh.Send(conn)
	return conn, nil
}

func newListenerWrapper(t *testing.T, lis net.Listener) *listenerWrapper {
	if lis == nil {
		var err error
		lis, err = testutils.LocalTCPListener()
		if err != nil {
			t.Fatal(err)
		}
	}

	return &listenerWrapper{
		Listener:  lis,
		newConnCh: testutils.NewChannel(),
	}
}

// fakeBackoffStrategy is a fake implementation of the backoff.Strategy
// interface, for tests to inject the backoff duration.
type fakeBackoffStrategy struct {
	backoff time.Duration
}

func (f *fakeBackoffStrategy) Backoff(retries int) time.Duration {
	return f.backoff
}

// fakeThrottler is a fake implementation of the adaptiveThrottler interface.
type fakeThrottler struct {
	throttleFunc func() bool
}

func (f *fakeThrottler) ShouldThrottle() bool             { return f.throttleFunc() }
func (f *fakeThrottler) RegisterBackendResponse(bool)     {}

// alwaysThrottlingThrottler returns a fake throttler which always throttles.
func alwaysThrottlingThrottler() *fakeThrottler {
	return &fakeThrottler{throttleFunc: func() bool { return true }}
}

// neverThrottlingThrottler returns a fake throttler which never throttles.
func neverThrottlingThrottler() *fakeThrottler {
	return &fakeThrottler{throttleFunc: func() bool { return false }}
}

// oneTimeAllowingThrottler returns a fake throttler which does not throttle the
// first request, but throttles everything that comes after. This is useful for
// tests which need to set up a valid cache entry before testing other cases.
func oneTimeAllowingThrottler() *fakeThrottler {
	var once sync.Once
	return &fakeThrottler{
		throttleFunc: func() bool {
			throttle := true
			once.Do(func() { throttle = false })
			return throttle
		},
	}
}

func overrideAdaptiveThrottler(t *testing.T, f *fakeThrottler) {
	origAdaptiveThrottler := newAdaptiveThrottler
	newAdaptiveThrottler = func() adaptiveThrottler { return f }
	t.Cleanup(func() { newAdaptiveThrottler = origAdaptiveThrottler })
}

// setupFakeRLSServer starts and returns a fake RouteLookupService server
// listening on the given listener or on a random local port. Also returns a
// channel for tests to get notified whenever the RouteLookup RPC is invoked on
// the fake server.
//
// This function sets up the fake server to respond with an empty response for
// the RouteLookup RPCs. Tests can override this by calling the
// SetResponseCallback() method on the returned fake server.
func setupFakeRLSServer(t *testing.T, lis net.Listener, opts ...grpc.ServerOption) (*e2e.FakeRouteLookupServer, chan struct{}) {
	s, cancel := e2e.StartFakeRouteLookupServer(t, lis, opts...)
	t.Logf("Started fake RLS server at %q", s.Address)

	ch := make(chan struct{}, 1)
	s.SetRequestCallback(func(request *rlspb.RouteLookupRequest) {
		select {
		case ch <- struct{}{}:
		default:
		}
	})
	t.Cleanup(cancel)
	return s, ch
}

// buildBasicRLSConfig constructs a basic service config for the RLS LB policy
// with header matching rules. This expects the passed child policy name to
// have been registered by the caller.
func buildBasicRLSConfig(childPolicyName, rlsServerAddress string) *e2e.RLSConfig {
	return &e2e.RLSConfig{
		RouteLookupConfig: &rlspb.RouteLookupConfig{
			GrpcKeybuilders: []*rlspb.GrpcKeyBuilder{
				{
					Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "grpc.testing.TestService"}},
					Headers: []*rlspb.NameMatcher{
						{Key: "k1", Names: []string{"n1"}},
						{Key: "k2", Names: []string{"n2"}},
					},
				},
			},
			LookupService:        rlsServerAddress,
			LookupServiceTimeout: durationpb.New(defaultTestTimeout),
			CacheSizeBytes:       1024,
		},
		ChildPolicy:                      &internalserviceconfig.BalancerConfig{Name: childPolicyName},
		ChildPolicyConfigTargetFieldName: e2e.RLSChildPolicyTargetNameField,
	}
}

// buildBasicRLSConfigWithChildPolicy constructs a very basic service config for
// the RLS LB policy. It also registers a test LB policy which is capable of
// being a child of the RLS LB policy.
func buildBasicRLSConfigWithChildPolicy(t *testing.T, childPolicyName, rlsServerAddress string) *e2e.RLSConfig {
	childPolicyName = "test-child-policy" + childPolicyName
	e2e.RegisterRLSChildPolicy(childPolicyName, nil)
	t.Logf("Registered child policy with name %q", childPolicyName)

	return &e2e.RLSConfig{
		RouteLookupConfig: &rlspb.RouteLookupConfig{
			GrpcKeybuilders:      []*rlspb.GrpcKeyBuilder{{Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "grpc.testing.TestService"}}}},
			LookupService:        rlsServerAddress,
			LookupServiceTimeout: durationpb.New(defaultTestTimeout),
			CacheSizeBytes:       1024,
		},
		ChildPolicy:                      &internalserviceconfig.BalancerConfig{Name: childPolicyName},
		ChildPolicyConfigTargetFieldName: e2e.RLSChildPolicyTargetNameField,
	}
}

// startBackend starts a backend implementing the TestService on a local port.
// It returns a channel for tests to get notified whenever an RPC is invoked on
// the backend. This allows tests to ensure that RPCs reach expected backends.
// Also returns the address of the backend.
func startBackend(t *testing.T, sopts ...grpc.ServerOption) (rpcCh chan struct{}, address string) {
	t.Helper()

	rpcCh = make(chan struct{}, 1)
	backend := &stubserver.StubServer{
		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
			select {
			case rpcCh <- struct{}{}:
			default:
			}
			return &testpb.Empty{}, nil
		},
	}
	if err := backend.StartServer(sopts...); err != nil {
		t.Fatalf("Failed to start backend: %v", err)
	}
	t.Logf("Started TestService backend at: %q", backend.Address)
	t.Cleanup(func() { backend.Stop() })
	return rpcCh, backend.Address
}

// startManualResolverWithConfig registers and returns a manual resolver which
// pushes the RLS LB policy's service config on the channel.
func startManualResolverWithConfig(t *testing.T, rlsConfig *e2e.RLSConfig) *manual.Resolver {
	t.Helper()

	scJSON, err := rlsConfig.ServiceConfigJSON()
	if err != nil {
		t.Fatal(err)
	}

	sc := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(scJSON)
	r := manual.NewBuilderWithScheme("rls-e2e")
	r.InitialState(resolver.State{ServiceConfig: sc})
	t.Cleanup(r.Close)
	return r
}

// makeTestRPCAndExpectItToReachBackend is a test helper function which makes
// the EmptyCall RPC on the given ClientConn and verifies that it reaches a
// backend. The latter is accomplished by listening on the provided channel
// which gets pushed to whenever the backend in question gets an RPC.
func makeTestRPCAndExpectItToReachBackend(ctx context.Context, t *testing.T, cc *grpc.ClientConn, ch chan struct{}) {
	t.Helper()

	client := testgrpc.NewTestServiceClient(cc)
	if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
		t.Fatalf("TestService/EmptyCall() failed with error: %v", err)
	}
	select {
	case <-ctx.Done():
		t.Fatalf("Timeout when waiting for backend to receive RPC")
	case <-ch:
	}
}

// makeTestRPCAndVerifyError is a test helper function which makes the EmptyCall
// RPC on the given ClientConn and verifies that the RPC fails with the given
// status code and error.
func makeTestRPCAndVerifyError(ctx context.Context, t *testing.T, cc *grpc.ClientConn, wantCode codes.Code, wantErr error) {
	t.Helper()

	client := testgrpc.NewTestServiceClient(cc)
	_, err := client.EmptyCall(ctx, &testpb.Empty{})
	if err == nil {
		t.Fatal("TestService/EmptyCall() succeeded when expected to fail")
	}
	if code := status.Code(err); code != wantCode {
		t.Fatalf("TestService/EmptyCall() returned code: %v, want: %v", code, wantCode)
	}
	if wantErr != nil && !strings.Contains(err.Error(), wantErr.Error()) {
		t.Fatalf("TestService/EmptyCall() returned err: %v, want: %v", err, wantErr)
	}
}

// verifyRLSRequest is a test helper which listens on a channel to see if an RLS
// request was received by the fake RLS server. Based on whether the test
// expects a request to be sent out or not, it uses a different timeout.
func verifyRLSRequest(t *testing.T, ch chan struct{}, wantRequest bool) {
	t.Helper()

	if wantRequest {
		select {
		case <-time.After(defaultTestTimeout):
			t.Fatalf("Timeout when waiting for an RLS request to be sent out")
		case <-ch:
		}
	} else {
		select {
		case <-time.After(defaultTestShortTimeout):
		case <-ch:
			t.Fatalf("RLS request sent out when not expecting one")
		}
	}
}
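Taken together, these helpers support an end-to-end test flow along the following lines. This is a sketch only: the dial target string is illustrative, and wiring the backend address into the fake server's response is elided.

rlsServer, rlsReqCh := setupFakeRLSServer(t, nil)
backendCh, backendAddr := startBackend(t)
_ = backendAddr // would be returned in the fake server's RouteLookup response

rlsConfig := buildBasicRLSConfigWithChildPolicy(t, t.Name(), rlsServer.Address)
r := startManualResolverWithConfig(t, rlsConfig)

cc, err := grpc.Dial(r.Scheme()+":///rls.test", grpc.WithResolvers(r), grpc.WithInsecure())
if err != nil {
	t.Fatalf("grpc.Dial() failed: %v", err)
}
defer cc.Close()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
makeTestRPCAndExpectItToReachBackend(ctx, t, cc, backendCh)
verifyRLSRequest(t, rlsReqCh, true)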
@ -1,147 +0,0 @@
/*
 *
 * Copyright 2020 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rls

import (
	"errors"
	"time"

	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/rls/internal/cache"
	"google.golang.org/grpc/balancer/rls/internal/keys"
	"google.golang.org/grpc/metadata"
)

var errRLSThrottled = errors.New("RLS call throttled at client side")

// rlsPicker selects the subConn to be used for a particular RPC. It does not
// manage subConns directly and usually delegates to pickers provided by
// child policies.
//
// The RLS LB policy creates a new rlsPicker object whenever its ServiceConfig
// is updated and provides a bunch of hooks for the rlsPicker to get the latest
// state that it can use to make its decision.
type rlsPicker struct {
	// The keyBuilder map used to generate RLS keys for the RPC. This is built
	// by the LB policy based on the received ServiceConfig.
	kbm keys.BuilderMap
	// Endpoint from the user's original dial target. Used to set the `host_key`
	// field in `extra_keys`.
	origEndpoint string

	// The following hooks are set up by the LB policy to enable the rlsPicker to
	// access state stored in the policy. This approach has the following
	// advantages:
	// 1. The rlsPicker is loosely coupled with the LB policy in the sense that
	//    updates happening on the LB policy like the receipt of an RLS
	//    response, or an update to the default rlsPicker etc are not explicitly
	//    pushed to the rlsPicker, but are readily available to the rlsPicker
	//    when it invokes these hooks. And the LB policy takes care of
	//    synchronizing access to this shared state.
	// 2. It makes unit testing the rlsPicker easy since any number of these
	//    hooks could be overridden.

	// readCache is used to read from the data cache and the pending request
	// map in an atomic fashion. The first return parameter is the entry in the
	// data cache, and the second indicates whether an entry for the same key
	// is present in the pending cache.
	readCache func(cache.Key) (*cache.Entry, bool)
	// shouldThrottle decides if the current RPC should be throttled at the
	// client side. It uses an adaptive throttling algorithm.
	shouldThrottle func() bool
	// startRLS kicks off an RLS request in the background for the provided RPC
	// path and keyMap. An entry in the pending request map is created before
	// sending out the request and an entry in the data cache is created or
	// updated upon receipt of a response. See implementation in the LB policy
	// for details.
	startRLS func(string, keys.KeyMap)
	// defaultPick enables the rlsPicker to delegate the pick decision to the
	// rlsPicker returned by the child LB policy pointing to the default target
	// specified in the service config.
	defaultPick func(balancer.PickInfo) (balancer.PickResult, error)
}

// Pick makes the routing decision for every outbound RPC.
func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
	// Build the request's keys using the key builders from LB config.
	md, _ := metadata.FromOutgoingContext(info.Ctx)
	km := p.kbm.RLSKey(md, p.origEndpoint, info.FullMethodName)

	// We use the LB policy hook to read the data cache and the pending request
	// map (whether or not an entry exists) for the RPC path and the generated
	// RLS keys. We will end up kicking off an RLS request only if there is no
	// pending request for the current RPC path and keys, and either we didn't
	// find an entry in the data cache or the entry was stale and it wasn't in
	// backoff.
	startRequest := false
	now := time.Now()
	entry, pending := p.readCache(cache.Key{Path: info.FullMethodName, KeyMap: km.Str})
	if entry == nil {
		startRequest = true
	} else {
		entry.Mu.Lock()
		defer entry.Mu.Unlock()
		if entry.StaleTime.Before(now) && entry.BackoffTime.Before(now) {
			// This is the proactive cache refresh.
			startRequest = true
		}
	}

	if startRequest && !pending {
		if p.shouldThrottle() {
			// The entry doesn't exist or has expired and the new RLS request
			// has been throttled. Treat it as an error and delegate to default
			// pick, if one exists, or fail the pick.
			if entry == nil || entry.ExpiryTime.Before(now) {
				if p.defaultPick != nil {
					return p.defaultPick(info)
				}
				return balancer.PickResult{}, errRLSThrottled
			}
			// The proactive refresh has been throttled. Nothing to worry
			// about; just keep using the existing entry.
		} else {
			p.startRLS(info.FullMethodName, km)
		}
	}

	if entry != nil {
		if entry.ExpiryTime.After(now) {
			// This is the jolly good case where we have found a valid entry in
			// the data cache. We delegate to the LB policy associated with
			// this cache entry.
			return entry.ChildPicker.Pick(info)
		} else if entry.BackoffTime.After(now) {
			// The entry has expired, but is in backoff. We delegate to the
			// default pick, if one exists, or return the error from the last
			// failed RLS request for this entry.
			if p.defaultPick != nil {
				return p.defaultPick(info)
			}
			return balancer.PickResult{}, entry.CallStatus
		}
	}

	// We get here only in the following cases:
	// * No data cache entry or expired entry, RLS request sent out
	// * No valid data cache entry and pending cache entry exists
	// We need to queue the pick, which will be handled once the RLS response
	// is received.
	return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}
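To make the hook-based decoupling concrete, here is a rough sketch of how the LB policy could construct a picker. The lb-side field and method names below are hypothetical, not taken from this commit; only the rlsPicker fields are real:

picker := &rlsPicker{
	kbm:            lb.lbCfg.kbMap,        // key builders from the current service config (assumed field)
	origEndpoint:   lb.dialTarget.Endpoint, // endpoint of the user's dial target (assumed field)
	readCache:      lb.readCache,           // atomic read of data cache + pending map (assumed method)
	shouldThrottle: lb.throttler.ShouldThrottle,
	startRLS:       lb.startRLSRequest,     // kicks off a background RouteLookup RPC (assumed method)
}
if lb.defaultChildPicker != nil {
	picker.defaultPick = lb.defaultChildPicker.Pick
}

Because each hook is just a func value, unit tests can substitute any of them without standing up a real LB policy, which is exactly what the picker tests below do.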
@ -1,615 +0,0 @@
/*
 *
 * Copyright 2020 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rls

import (
	"context"
	"errors"
	"fmt"
	"math"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"

	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/rls/internal/cache"
	"google.golang.org/grpc/balancer/rls/internal/keys"
	"google.golang.org/grpc/internal/grpcrand"
	rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
	"google.golang.org/grpc/internal/testutils"
	"google.golang.org/grpc/metadata"
)

const defaultTestMaxAge = 5 * time.Second

// initKeyBuilderMap initializes a keyBuilderMap of the form:
//	{
//		"gFoo": "k1=n1",
//		"gBar/method1": "k2=n21,n22",
//		"gFoobar": "k3=n3",
//	}
func initKeyBuilderMap() (keys.BuilderMap, error) {
	kb1 := &rlspb.GrpcKeyBuilder{
		Names:   []*rlspb.GrpcKeyBuilder_Name{{Service: "gFoo"}},
		Headers: []*rlspb.NameMatcher{{Key: "k1", Names: []string{"n1"}}},
	}
	kb2 := &rlspb.GrpcKeyBuilder{
		Names:   []*rlspb.GrpcKeyBuilder_Name{{Service: "gBar", Method: "method1"}},
		Headers: []*rlspb.NameMatcher{{Key: "k2", Names: []string{"n21", "n22"}}},
	}
	kb3 := &rlspb.GrpcKeyBuilder{
		Names:   []*rlspb.GrpcKeyBuilder_Name{{Service: "gFoobar"}},
		Headers: []*rlspb.NameMatcher{{Key: "k3", Names: []string{"n3"}}},
	}
	return keys.MakeBuilderMap(&rlspb.RouteLookupConfig{
		GrpcKeybuilders: []*rlspb.GrpcKeyBuilder{kb1, kb2, kb3},
	})
}

// fakeSubConn embeds the balancer.SubConn interface and contains an id which
// helps verify that the expected subConn was returned by the rlsPicker.
type fakeSubConn struct {
	balancer.SubConn
	id int
}

// fakePicker sends a PickResult with a fakeSubConn with the configured id.
type fakePicker struct {
	id int
}

func (p *fakePicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) {
	return balancer.PickResult{SubConn: &fakeSubConn{id: p.id}}, nil
}

// newFakePicker returns a fakePicker configured with a random ID. The subConns
// returned by this picker are of type fakeSubConn, and contain the same
// random ID, which tests can use to verify.
func newFakePicker() *fakePicker {
	return &fakePicker{id: grpcrand.Intn(math.MaxInt32)}
}

func verifySubConn(sc balancer.SubConn, wantID int) error {
	fsc, ok := sc.(*fakeSubConn)
	if !ok {
		return fmt.Errorf("Pick() returned a SubConn of type %T, want %T", sc, &fakeSubConn{})
	}
	if fsc.id != wantID {
		return fmt.Errorf("Pick() returned SubConn %d, want %d", fsc.id, wantID)
	}
	return nil
}

// TestPickKeyBuilder verifies the different possible scenarios for forming an
// RLS key for an incoming RPC.
func TestPickKeyBuilder(t *testing.T) {
	kbm, err := initKeyBuilderMap()
	if err != nil {
		t.Fatalf("Failed to create keyBuilderMap: %v", err)
	}

	tests := []struct {
		desc    string
		rpcPath string
		md      metadata.MD
		wantKey cache.Key
	}{
		{
			desc:    "non-existent service in keyBuilder map",
			rpcPath: "/gNonExistentService/method",
			md:      metadata.New(map[string]string{"n1": "v1", "n3": "v3"}),
			wantKey: cache.Key{Path: "/gNonExistentService/method", KeyMap: ""},
		},
		{
			desc:    "no metadata in incoming context",
			rpcPath: "/gFoo/method",
			md:      metadata.MD{},
			wantKey: cache.Key{Path: "/gFoo/method", KeyMap: ""},
		},
		{
			desc:    "keyBuilderMatch",
			rpcPath: "/gFoo/method",
			md:      metadata.New(map[string]string{"n1": "v1", "n3": "v3"}),
			wantKey: cache.Key{Path: "/gFoo/method", KeyMap: "k1=v1"},
		},
	}

	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			randID := grpcrand.Intn(math.MaxInt32)
			p := rlsPicker{
				kbm: kbm,
				readCache: func(key cache.Key) (*cache.Entry, bool) {
					if !cmp.Equal(key, test.wantKey) {
						t.Fatalf("rlsPicker using cacheKey %v, want %v", key, test.wantKey)
					}

					now := time.Now()
					return &cache.Entry{
						ExpiryTime: now.Add(defaultTestMaxAge),
						StaleTime:  now.Add(defaultTestMaxAge),
						// Cache entry is configured with a child policy whose
						// rlsPicker always returns an empty PickResult and nil
						// error.
						ChildPicker: &fakePicker{id: randID},
					}, false
				},
				// The other hooks are not set here because they are not expected to be
				// invoked for these cases and if they get invoked, they will panic.
			}

			gotResult, err := p.Pick(balancer.PickInfo{
				FullMethodName: test.rpcPath,
				Ctx:            metadata.NewOutgoingContext(context.Background(), test.md),
			})
			if err != nil {
				t.Fatalf("Pick() failed with error: %v", err)
			}
			sc, ok := gotResult.SubConn.(*fakeSubConn)
			if !ok {
				t.Fatalf("Pick() returned a SubConn of type %T, want %T", gotResult.SubConn, &fakeSubConn{})
			}
			if sc.id != randID {
				t.Fatalf("Pick() returned SubConn %d, want %d", sc.id, randID)
			}
		})
	}
}

// TestPick_DataCacheMiss_PendingCacheMiss verifies different Pick scenarios
// where the entry is neither found in the data cache nor in the pending cache.
func TestPick_DataCacheMiss_PendingCacheMiss(t *testing.T) {
	const (
		rpcPath       = "/gFoo/method"
		wantKeyMapStr = "k1=v1"
	)
	kbm, err := initKeyBuilderMap()
	if err != nil {
		t.Fatalf("Failed to create keyBuilderMap: %v", err)
	}
	md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"})
	wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr}

	tests := []struct {
		desc string
		// Whether or not a default target is configured.
		defaultPickExists bool
		// Whether or not the RLS request should be throttled.
		throttle bool
		// Whether or not the test is expected to make a new RLS request.
		wantRLSRequest bool
		// Expected error returned by the rlsPicker under test.
		wantErr error
	}{
		{
			desc:              "rls request throttled with default pick",
			defaultPickExists: true,
			throttle:          true,
		},
		{
			desc:     "rls request throttled without default pick",
			throttle: true,
			wantErr:  errRLSThrottled,
		},
		{
			desc:           "rls request not throttled",
			wantRLSRequest: true,
			wantErr:        balancer.ErrNoSubConnAvailable,
		},
	}

	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			rlsCh := testutils.NewChannel()
			defaultPicker := newFakePicker()

			p := rlsPicker{
				kbm: kbm,
				// Cache lookup fails, no pending entry.
				readCache: func(key cache.Key) (*cache.Entry, bool) {
					if !cmp.Equal(key, wantKey) {
						t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey)
					}
					return nil, false
				},
				shouldThrottle: func() bool { return test.throttle },
				startRLS: func(path string, km keys.KeyMap) {
					if !test.wantRLSRequest {
						rlsCh.Send(errors.New("RLS request attempted when none was expected"))
						return
					}
					if path != rpcPath {
						rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath))
						return
					}
					if km.Str != wantKeyMapStr {
						rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr))
						return
					}
					rlsCh.Send(nil)
				},
			}
			if test.defaultPickExists {
				p.defaultPick = defaultPicker.Pick
			}

			gotResult, err := p.Pick(balancer.PickInfo{
				FullMethodName: rpcPath,
				Ctx:            metadata.NewOutgoingContext(context.Background(), md),
			})
			if err != test.wantErr {
				t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr)
			}
			// If the test specified that a new RLS request should be made,
			// verify it.
			if test.wantRLSRequest {
				ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
				defer cancel()
				if rlsErr, err := rlsCh.Receive(ctx); err != nil || rlsErr != nil {
					t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err)
				}
			}
			if test.wantErr != nil {
				return
			}

			// We get here only for cases where we expect the pick to be
			// delegated to the default picker.
			if err := verifySubConn(gotResult.SubConn, defaultPicker.id); err != nil {
				t.Fatal(err)
			}
		})
	}
}

// TestPick_DataCacheMiss_PendingCacheHit verifies different Pick scenarios
// where the entry is not found in the data cache, but there is an entry in the
// pending cache. For all of these scenarios, no new RLS request will be sent.
func TestPick_DataCacheMiss_PendingCacheHit(t *testing.T) {
	const (
		rpcPath       = "/gFoo/method"
		wantKeyMapStr = "k1=v1"
	)
	kbm, err := initKeyBuilderMap()
	if err != nil {
		t.Fatalf("Failed to create keyBuilderMap: %v", err)
	}
	md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"})
	wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr}

	tests := []struct {
		desc              string
		defaultPickExists bool
	}{
		{
			desc:              "default pick exists",
			defaultPickExists: true,
		},
		{
			desc: "default pick does not exist",
		},
	}
	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			rlsCh := testutils.NewChannel()
			p := rlsPicker{
				kbm: kbm,
				// Cache lookup fails, pending entry exists.
				readCache: func(key cache.Key) (*cache.Entry, bool) {
					if !cmp.Equal(key, wantKey) {
						t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey)
					}
					return nil, true
				},
				// Never throttle. We do not expect an RLS request to be sent out anyway.
				shouldThrottle: func() bool { return false },
				startRLS: func(_ string, _ keys.KeyMap) {
					rlsCh.Send(nil)
				},
			}
			if test.defaultPickExists {
				p.defaultPick = func(info balancer.PickInfo) (balancer.PickResult, error) {
					// We do not expect the default picker to be invoked at all.
					// So, if we get here, the test will fail, because it
					// expects the pick to be queued.
					return balancer.PickResult{}, nil
				}
			}

			if _, err := p.Pick(balancer.PickInfo{
				FullMethodName: rpcPath,
				Ctx:            metadata.NewOutgoingContext(context.Background(), md),
			}); err != balancer.ErrNoSubConnAvailable {
				t.Fatalf("Pick() returned error {%v}, want {%v}", err, balancer.ErrNoSubConnAvailable)
			}

			// Make sure that no RLS request was sent out.
			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
			defer cancel()
			if _, err := rlsCh.Receive(ctx); err != context.DeadlineExceeded {
				t.Fatalf("RLS request sent out when pending entry exists")
			}
		})
	}
}

// TestPick_DataCacheHit_PendingCacheMiss verifies different Pick scenarios
// where the entry is found in the data cache, and there is no entry in the
// pending cache. This includes cases where the entry in the data cache is
// stale, expired or in backoff.
func TestPick_DataCacheHit_PendingCacheMiss(t *testing.T) {
	const (
		rpcPath       = "/gFoo/method"
		wantKeyMapStr = "k1=v1"
	)
	kbm, err := initKeyBuilderMap()
	if err != nil {
		t.Fatalf("Failed to create keyBuilderMap: %v", err)
	}
	md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"})
	wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr}
	rlsLastErr := errors.New("last RLS request failed")

	tests := []struct {
		desc string
		// The cache entry, as returned by the overridden readCache hook.
		cacheEntry *cache.Entry
		// Whether or not a default target is configured.
		defaultPickExists bool
		// Whether or not the RLS request should be throttled.
		throttle bool
		// Whether or not the test is expected to make a new RLS request.
		wantRLSRequest bool
		// Whether or not the rlsPicker should delegate to the child picker.
		wantChildPick bool
		// Whether or not the rlsPicker should delegate to the default picker.
		wantDefaultPick bool
		// Expected error returned by the rlsPicker under test.
		wantErr error
	}{
		{
			desc: "valid entry",
			cacheEntry: &cache.Entry{
				ExpiryTime: time.Now().Add(defaultTestMaxAge),
				StaleTime:  time.Now().Add(defaultTestMaxAge),
			},
			wantChildPick: true,
		},
		{
			desc:          "entryStale_requestThrottled",
			cacheEntry:    &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)},
			throttle:      true,
			wantChildPick: true,
		},
		{
			desc:           "entryStale_requestNotThrottled",
			cacheEntry:     &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)},
			wantRLSRequest: true,
			wantChildPick:  true,
		},
		{
			desc:              "entryExpired_requestThrottled_defaultPickExists",
			cacheEntry:        &cache.Entry{},
			throttle:          true,
			defaultPickExists: true,
			wantDefaultPick:   true,
		},
		{
			desc:       "entryExpired_requestThrottled_defaultPickNotExists",
			cacheEntry: &cache.Entry{},
			throttle:   true,
			wantErr:    errRLSThrottled,
		},
		{
			desc:           "entryExpired_requestNotThrottled",
			cacheEntry:     &cache.Entry{},
			wantRLSRequest: true,
			wantErr:        balancer.ErrNoSubConnAvailable,
		},
		{
			desc: "entryExpired_backoffNotExpired_defaultPickExists",
			cacheEntry: &cache.Entry{
				BackoffTime: time.Now().Add(defaultTestMaxAge),
				CallStatus:  rlsLastErr,
			},
			defaultPickExists: true,
		},
		{
			desc: "entryExpired_backoffNotExpired_defaultPickNotExists",
			cacheEntry: &cache.Entry{
				BackoffTime: time.Now().Add(defaultTestMaxAge),
				CallStatus:  rlsLastErr,
			},
			wantErr: rlsLastErr,
		},
	}

	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			rlsCh := testutils.NewChannel()
			childPicker := newFakePicker()
			defaultPicker := newFakePicker()

			p := rlsPicker{
				kbm: kbm,
				readCache: func(key cache.Key) (*cache.Entry, bool) {
					if !cmp.Equal(key, wantKey) {
						t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey)
					}
					test.cacheEntry.ChildPicker = childPicker
					return test.cacheEntry, false
				},
				shouldThrottle: func() bool { return test.throttle },
				startRLS: func(path string, km keys.KeyMap) {
					if !test.wantRLSRequest {
						rlsCh.Send(errors.New("RLS request attempted when none was expected"))
						return
					}
					if path != rpcPath {
						rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath))
						return
					}
					if km.Str != wantKeyMapStr {
						rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr))
						return
					}
					rlsCh.Send(nil)
				},
			}
			if test.defaultPickExists {
				p.defaultPick = defaultPicker.Pick
			}

			gotResult, err := p.Pick(balancer.PickInfo{
				FullMethodName: rpcPath,
				Ctx:            metadata.NewOutgoingContext(context.Background(), md),
			})
			if err != test.wantErr {
				t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr)
			}
			// If the test specified that a new RLS request should be made,
			// verify it.
			if test.wantRLSRequest {
				ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
				defer cancel()
				if rlsErr, err := rlsCh.Receive(ctx); err != nil || rlsErr != nil {
					t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err)
				}
			}
			if test.wantErr != nil {
				return
			}

			// We get here only for cases where we expect the pick to be
			// delegated to the child picker or the default picker.
			if test.wantChildPick {
				if err := verifySubConn(gotResult.SubConn, childPicker.id); err != nil {
					t.Fatal(err)
				}
			}
			if test.wantDefaultPick {
				if err := verifySubConn(gotResult.SubConn, defaultPicker.id); err != nil {
					t.Fatal(err)
				}
			}
		})
	}
}

// TestPick_DataCacheHit_PendingCacheHit verifies different Pick scenarios where
// the entry is found both in the data cache and in the pending cache. This
// mostly verifies cases where the entry is stale, but there is already a
// pending RLS request, so no new request should be sent out.
func TestPick_DataCacheHit_PendingCacheHit(t *testing.T) {
	const (
		rpcPath       = "/gFoo/method"
		wantKeyMapStr = "k1=v1"
	)
	kbm, err := initKeyBuilderMap()
	if err != nil {
		t.Fatalf("Failed to create keyBuilderMap: %v", err)
	}
	md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"})
	wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr}

	tests := []struct {
		desc string
		// The cache entry, as returned by the overridden readCache hook.
		cacheEntry *cache.Entry
		// Whether or not a default target is configured.
		defaultPickExists bool
		// Expected error returned by the rlsPicker under test.
		wantErr error
	}{
		{
			desc:       "stale entry",
			cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)},
		},
		{
			desc:              "stale entry with default picker",
			cacheEntry:        &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)},
			defaultPickExists: true,
		},
		{
			desc:              "entryExpired_defaultPickExists",
			cacheEntry:        &cache.Entry{},
			defaultPickExists: true,
			wantErr:           balancer.ErrNoSubConnAvailable,
		},
		{
			desc:       "entryExpired_defaultPickNotExists",
			cacheEntry: &cache.Entry{},
			wantErr:    balancer.ErrNoSubConnAvailable,
		},
	}
	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			rlsCh := testutils.NewChannel()
			childPicker := newFakePicker()

			p := rlsPicker{
				kbm: kbm,
				readCache: func(key cache.Key) (*cache.Entry, bool) {
					if !cmp.Equal(key, wantKey) {
						t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey)
					}
					test.cacheEntry.ChildPicker = childPicker
					return test.cacheEntry, true
				},
				// Never throttle. We do not expect an RLS request to be sent out anyway.
				shouldThrottle: func() bool { return false },
				startRLS: func(path string, km keys.KeyMap) {
					rlsCh.Send(nil)
				},
			}
			if test.defaultPickExists {
				p.defaultPick = func(info balancer.PickInfo) (balancer.PickResult, error) {
					// We do not expect the default picker to be invoked at all.
					// So, if we get here, we return an error.
					return balancer.PickResult{}, errors.New("default picker invoked when expecting a child pick")
				}
			}

			gotResult, err := p.Pick(balancer.PickInfo{
				FullMethodName: rpcPath,
				Ctx:            metadata.NewOutgoingContext(context.Background(), md),
			})
			if err != test.wantErr {
				t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr)
			}
			// Make sure that no RLS request was sent out.
			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
			defer cancel()
			if _, err := rlsCh.Receive(ctx); err != context.DeadlineExceeded {
				t.Fatalf("RLS request sent out when pending entry exists")
			}
			if test.wantErr != nil {
				return
			}

			// We get here only for cases where we expect the pick to be
			// delegated to the child picker.
			if err := verifySubConn(gotResult.SubConn, childPicker.id); err != nil {
				t.Fatal(err)
			}
		})
	}
}
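These tests lean on one common signaling pattern: an overridden hook records success or failure on a testutils.Channel, and the test body receives from it with a deadline. Isolated from the table tests above, the pattern looks roughly like this (a sketch, assuming testutils.Channel's Send/Receive semantics as used throughout this file):

rlsCh := testutils.NewChannel()

// Inside an overridden hook, report the outcome:
rlsCh.Send(nil) // or: rlsCh.Send(errors.New("unexpected call"))

// Back in the test body, wait for the hook with a deadline:
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if got, err := rlsCh.Receive(ctx); err != nil || got != nil {
	// Either the hook never fired before the deadline (err != nil), or it
	// fired and reported a failure (got != nil).
}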
@ -0,0 +1,20 @@
/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package e2e contains utilities for end-to-end RouteLookupService tests.
package e2e
@ -0,0 +1,131 @@
/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package e2e

import (
	"encoding/json"
	"errors"
	"fmt"

	"google.golang.org/grpc"
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/internal/grpcsync"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/serviceconfig"
)

const (
	// RLSChildPolicyTargetNameField is a top-level field name to add to the child
	// policy's config, whose value is set to the target for the child policy.
	RLSChildPolicyTargetNameField = "Backend"
	// RLSChildPolicyBadTarget is a value which is considered a bad target by the
	// child policy. This is useful to test bad child policy configuration.
	RLSChildPolicyBadTarget = "bad-target"
)

// ErrParseConfigBadTarget is the error returned from ParseConfig when the
// backend field is set to RLSChildPolicyBadTarget.
var ErrParseConfigBadTarget = errors.New("backend field set to RLSChildPolicyBadTarget")

// BalancerFuncs is a set of callbacks which get invoked when the corresponding
// method on the child policy is invoked.
type BalancerFuncs struct {
	UpdateClientConnState func(cfg *RLSChildPolicyConfig) error
	Close                 func()
}

// RegisterRLSChildPolicy registers a balancer builder with the given name, to
// be used as a child policy for the RLS LB policy.
//
// The child policy uses a pickfirst balancer under the hood to send all traffic
// to the single backend specified by the `RLSChildPolicyTargetNameField` field
// in its configuration which looks like: {"Backend": "Backend-address"}.
func RegisterRLSChildPolicy(name string, bf *BalancerFuncs) {
	balancer.Register(bb{name: name, bf: bf})
}

type bb struct {
	name string
	bf   *BalancerFuncs
}

func (bb bb) Name() string { return bb.name }

func (bb bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
	pf := balancer.Get(grpc.PickFirstBalancerName)
	b := &bal{
		Balancer: pf.Build(cc, opts),
		bf:       bb.bf,
		done:     grpcsync.NewEvent(),
	}
	go b.run()
	return b
}

func (bb bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
	cfg := &RLSChildPolicyConfig{}
	if err := json.Unmarshal(c, cfg); err != nil {
		return nil, err
	}
	if cfg.Backend == RLSChildPolicyBadTarget {
		return nil, ErrParseConfigBadTarget
	}
	return cfg, nil
}

type bal struct {
	balancer.Balancer
	bf   *BalancerFuncs
	done *grpcsync.Event
}

// RLSChildPolicyConfig is the LB config for the test child policy.
type RLSChildPolicyConfig struct {
	serviceconfig.LoadBalancingConfig
	Backend string // The target for which this child policy was created.
	Random  string // A random field to test child policy config changes.
}

func (b *bal) UpdateClientConnState(c balancer.ClientConnState) error {
	cfg, ok := c.BalancerConfig.(*RLSChildPolicyConfig)
	if !ok {
		return fmt.Errorf("received balancer config of type %T, want %T", c.BalancerConfig, &RLSChildPolicyConfig{})
	}
	if b.bf != nil && b.bf.UpdateClientConnState != nil {
		b.bf.UpdateClientConnState(cfg)
	}
	return b.Balancer.UpdateClientConnState(balancer.ClientConnState{
		ResolverState: resolver.State{Addresses: []resolver.Address{{Addr: cfg.Backend}}},
	})
}

func (b *bal) Close() {
	b.Balancer.Close()
	if b.bf != nil && b.bf.Close != nil {
		b.bf.Close()
	}
	b.done.Fire()
}

// run is a dummy goroutine to make sure that child policies are closed at the
// end of tests. If they are not closed, these goroutines will be picked up by
// the leak checker and tests will fail.
func (b *bal) run() {
	<-b.done.Done()
}
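As a rough illustration of the exported surface (the policy name and callback body here are hypothetical), a test could register this child policy and observe config updates through BalancerFuncs:

e2e.RegisterRLSChildPolicy("test-child-policy", &e2e.BalancerFuncs{
	UpdateClientConnState: func(cfg *e2e.RLSChildPolicyConfig) error {
		// cfg.Backend carries the target which the RLS policy asked this
		// child to route to; a test can assert on it here.
		return nil
	},
})

The RLS policy would then reference "test-child-policy" in its service config, with "Backend" as the childPolicyConfigTargetFieldName.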
@ -0,0 +1,110 @@
/*
 *
 * Copyright 2020 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package e2e

import (
	"context"
	"net"
	"sync"
	"testing"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
	rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
	"google.golang.org/grpc/internal/testutils"
	"google.golang.org/grpc/status"
)

// RouteLookupResponse wraps an RLS response and the associated error to be sent
// to a client when the RouteLookup RPC is invoked.
type RouteLookupResponse struct {
	Resp *rlspb.RouteLookupResponse
	Err  error
}

// FakeRouteLookupServer is a fake implementation of the RouteLookupService.
//
// It is safe for concurrent use.
type FakeRouteLookupServer struct {
	rlsgrpc.UnimplementedRouteLookupServiceServer
	Address string

	mu     sync.Mutex
	respCb func(context.Context, *rlspb.RouteLookupRequest) *RouteLookupResponse
	reqCb  func(*rlspb.RouteLookupRequest)
}

// StartFakeRouteLookupServer starts a fake RLS server listening for requests on
// lis. If lis is nil, it creates a new listener on a random local port. The
// returned cancel function should be invoked by the caller upon completion of
// the test.
func StartFakeRouteLookupServer(t *testing.T, lis net.Listener, opts ...grpc.ServerOption) (*FakeRouteLookupServer, func()) {
	t.Helper()

	if lis == nil {
		var err error
		lis, err = testutils.LocalTCPListener()
		if err != nil {
			t.Fatalf("net.Listen() failed: %v", err)
		}
	}

	s := &FakeRouteLookupServer{Address: lis.Addr().String()}
	server := grpc.NewServer(opts...)
	rlsgrpc.RegisterRouteLookupServiceServer(server, s)
	go server.Serve(lis)
	return s, func() { server.Stop() }
}

// RouteLookup implements the RouteLookupService.
func (s *FakeRouteLookupServer) RouteLookup(ctx context.Context, req *rlspb.RouteLookupRequest) (*rlspb.RouteLookupResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.reqCb != nil {
		s.reqCb(req)
	}
	if err := ctx.Err(); err != nil {
		return nil, status.Error(codes.DeadlineExceeded, err.Error())
	}
	if s.respCb == nil {
		return &rlspb.RouteLookupResponse{}, nil
	}
	resp := s.respCb(ctx, req)
	return resp.Resp, resp.Err
}

// SetResponseCallback sets a callback to be invoked on every RLS request. If
// this callback is set, the response returned by the fake server depends on the
// value returned by the callback. If this callback is not set, the fake server
// responds with an empty response.
func (s *FakeRouteLookupServer) SetResponseCallback(f func(context.Context, *rlspb.RouteLookupRequest) *RouteLookupResponse) {
	s.mu.Lock()
	s.respCb = f
	s.mu.Unlock()
}

// SetRequestCallback sets a callback to be invoked on every RLS request. The
// callback is given the incoming request, and tests can use this to verify that
// the request matches its expectations.
func (s *FakeRouteLookupServer) SetRequestCallback(f func(*rlspb.RouteLookupRequest)) {
	s.mu.Lock()
	s.reqCb = f
	s.mu.Unlock()
}
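A typical test points the fake server at a real backend by overriding the response callback. A sketch, with the backend address hypothetical:

rlsServer, cleanup := e2e.StartFakeRouteLookupServer(t, nil)
defer cleanup()
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse {
	// Route every lookup to a single (illustrative) backend address.
	return &e2e.RouteLookupResponse{
		Resp: &rlspb.RouteLookupResponse{Targets: []string{"localhost:50051"}},
	}
})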
@ -0,0 +1,100 @@
/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package e2e

import (
	"errors"
	"fmt"

	"google.golang.org/grpc/balancer"
	rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
	internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
	"google.golang.org/grpc/serviceconfig"

	"google.golang.org/protobuf/encoding/protojson"
)

// RLSConfig is a utility type to build service config for the RLS LB policy.
type RLSConfig struct {
	RouteLookupConfig                *rlspb.RouteLookupConfig
	ChildPolicy                      *internalserviceconfig.BalancerConfig
	ChildPolicyConfigTargetFieldName string
}

// ServiceConfigJSON generates service config with a load balancing config
// corresponding to the RLS LB policy.
func (c *RLSConfig) ServiceConfigJSON() (string, error) {
	m := protojson.MarshalOptions{
		Multiline:     true,
		Indent:        "  ",
		UseProtoNames: true,
	}
	routeLookupCfg, err := m.Marshal(c.RouteLookupConfig)
	if err != nil {
		return "", err
	}
	childPolicy, err := c.ChildPolicy.MarshalJSON()
	if err != nil {
		return "", err
	}

	return fmt.Sprintf(`
{
  "loadBalancingConfig": [
    {
      "rls_experimental": {
        "routeLookupConfig": %s,
        "childPolicy": %s,
        "childPolicyConfigTargetFieldName": %q
      }
    }
  ]
}`, string(routeLookupCfg), string(childPolicy), c.ChildPolicyConfigTargetFieldName), nil
}

// LoadBalancingConfig generates load balancing config which can be used as
// part of a ClientConnState update to the RLS LB policy.
func (c *RLSConfig) LoadBalancingConfig() (serviceconfig.LoadBalancingConfig, error) {
	m := protojson.MarshalOptions{
		Multiline:     true,
		Indent:        "  ",
		UseProtoNames: true,
	}
	routeLookupCfg, err := m.Marshal(c.RouteLookupConfig)
	if err != nil {
		return nil, err
	}
	childPolicy, err := c.ChildPolicy.MarshalJSON()
	if err != nil {
		return nil, err
	}
	lbConfigJSON := fmt.Sprintf(`
{
  "routeLookupConfig": %s,
  "childPolicy": %s,
  "childPolicyConfigTargetFieldName": %q
}`, string(routeLookupCfg), string(childPolicy), c.ChildPolicyConfigTargetFieldName)

	builder := balancer.Get("rls_experimental")
	if builder == nil {
		return nil, errors.New("balancer builder not found for RLS LB policy")
	}
	parser := builder.(balancer.ConfigParser)
	return parser.ParseConfig([]byte(lbConfigJSON))
}
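For reference, with a config pointing at a hypothetical RLS server on localhost:50051 and a child policy named "test-child-policy", ServiceConfigJSON would produce JSON shaped roughly like the following (field values illustrative; protojson with UseProtoNames yields snake_case keys, and BalancerConfig marshals the child policy as a one-element list):

{
  "loadBalancingConfig": [
    {
      "rls_experimental": {
        "routeLookupConfig": {
          "grpc_keybuilders": [ ... ],
          "lookup_service": "localhost:50051",
          "cache_size_bytes": "1024"
        },
        "childPolicy": [{"test-child-policy": {}}],
        "childPolicyConfigTargetFieldName": "Backend"
      }
    }
  ]
}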
@ -80,6 +80,14 @@ func (ss *StubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallSer
// Start starts the server and creates a client connected to it.
func (ss *StubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) error {
	if err := ss.StartServer(sopts...); err != nil {
		return err
	}
	return ss.StartClient(dopts...)
}

// StartServer only starts the server. It does not create a client to it.
func (ss *StubServer) StartServer(sopts ...grpc.ServerOption) error {
	if ss.Network == "" {
		ss.Network = "tcp"
	}
@ -102,7 +110,12 @@ func (ss *StubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption)
	go s.Serve(lis)
	ss.cleanups = append(ss.cleanups, s.Stop)
	ss.S = s
	return nil
}

// StartClient creates a client connected to this service that the test may use.
// The newly created client will be available in the Client field of StubServer.
func (ss *StubServer) StartClient(dopts ...grpc.DialOption) error {
	opts := append([]grpc.DialOption{grpc.WithInsecure()}, dopts...)
	if ss.R != nil {
		ss.Target = ss.R.Scheme() + ":///" + ss.Address
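A minimal sketch of how a test might use the new server-only entry point (the handler body is trimmed; see startBackend earlier in this commit for a fuller example):

ss := &stubserver.StubServer{
	EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
		return &testpb.Empty{}, nil
	},
}
if err := ss.StartServer(); err != nil {
	t.Fatalf("StartServer() failed: %v", err)
}
defer ss.Stop()
// ss.Address now holds the backend's address; no client conn is created,
// which is what the RLS tests need when dialing through the LB policy.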
@ -83,12 +83,11 @@ func (l *RestartableListener) Addr() net.Addr {
func (l *RestartableListener) Stop() {
	l.mu.Lock()
	l.stopped = true
	tmp := l.conns
	l.conns = nil
	l.mu.Unlock()
	for _, conn := range tmp {
	for _, conn := range l.conns {
		conn.Close()
	}
	l.conns = nil
	l.mu.Unlock()
}

// Restart gets a previously stopped listener to start accepting connections.