rls: control channel implementation (#5046)

This commit is contained in:
Easwar Swaminathan 2021-12-15 09:37:05 -08:00 committed by GitHub
parent 7c8a9321b9
commit 50f82701b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1389 additions and 1427 deletions

View File

@ -19,183 +19,36 @@
package rls
import (
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/grpcsync"
)
var (
_ balancer.Balancer = (*rlsBalancer)(nil)
// For overriding in tests.
newRLSClientFunc = newRLSClient
logger = grpclog.Component("rls")
logger = grpclog.Component("rls")
)
// rlsBalancer implements the RLS LB policy.
type rlsBalancer struct {
done *grpcsync.Event
cc balancer.ClientConn
opts balancer.BuildOptions
type rlsBalancer struct{}
// Mutex protects all the state maintained by the LB policy.
// TODO(easwars): Once we add the cache, we will also have another lock for
// the cache alone.
mu sync.Mutex
lbCfg *lbConfig // Most recently received service config.
rlsCC *grpc.ClientConn // ClientConn to the RLS server.
rlsC *rlsClient // RLS client wrapper.
ccUpdateCh chan *balancer.ClientConnState
}
// run is a long running goroutine which handles all the updates that the
// balancer wishes to handle. The appropriate updateHandler will push the update
// on to a channel that this goroutine will select on, thereby the handling of
// the update will happen asynchronously.
func (lb *rlsBalancer) run() {
for {
// TODO(easwars): Handle other updates like subConn state changes, RLS
// responses from the server etc.
select {
case u := <-lb.ccUpdateCh:
lb.handleClientConnUpdate(u)
case <-lb.done.Done():
return
}
}
}
// handleClientConnUpdate handles updates to the service config.
// If the RLS server name or the RLS RPC timeout changes, it updates the control
// channel accordingly.
// TODO(easwars): Handle updates to other fields in the service config.
func (lb *rlsBalancer) handleClientConnUpdate(ccs *balancer.ClientConnState) {
logger.Infof("rls: service config: %+v", ccs.BalancerConfig)
lb.mu.Lock()
defer lb.mu.Unlock()
if lb.done.HasFired() {
logger.Warning("rls: received service config after balancer close")
return
}
newCfg := ccs.BalancerConfig.(*lbConfig)
if lb.lbCfg.Equal(newCfg) {
logger.Info("rls: new service config matches existing config")
return
}
lb.updateControlChannel(newCfg)
lb.lbCfg = newCfg
}
// UpdateClientConnState pushes the received ClientConnState update on the
// update channel which will be processed asynchronously by the run goroutine.
// Implements balancer.Balancer interface.
func (lb *rlsBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
select {
case lb.ccUpdateCh <- &ccs:
case <-lb.done.Done():
}
logger.Fatal("rls: UpdateClientConnState is not yet unimplemented")
return nil
}
// ResolverErr implements balancer.Balancer interface.
func (lb *rlsBalancer) ResolverError(error) {
// ResolverError is called by gRPC when the name resolver reports an error.
// TODO(easwars): How do we handle this?
logger.Fatal("rls: ResolverError is not yet unimplemented")
}
// UpdateSubConnState implements balancer.Balancer interface.
func (lb *rlsBalancer) UpdateSubConnState(_ balancer.SubConn, _ balancer.SubConnState) {
logger.Fatal("rls: UpdateSubConnState is not yet implemented")
}
// Cleans up the resources allocated by the LB policy including the clientConn
// to the RLS server.
// Implements balancer.Balancer.
func (lb *rlsBalancer) Close() {
lb.mu.Lock()
defer lb.mu.Unlock()
lb.done.Fire()
if lb.rlsCC != nil {
lb.rlsCC.Close()
}
logger.Fatal("rls: Close is not yet implemented")
}
func (lb *rlsBalancer) ExitIdle() {
// TODO: are we 100% sure this should be a nop?
}
// updateControlChannel updates the RLS client if required.
// Caller must hold lb.mu.
func (lb *rlsBalancer) updateControlChannel(newCfg *lbConfig) {
oldCfg := lb.lbCfg
if newCfg.lookupService == oldCfg.lookupService && newCfg.lookupServiceTimeout == oldCfg.lookupServiceTimeout {
return
}
// Use RPC timeout from new config, if different from existing one.
timeout := oldCfg.lookupServiceTimeout
if timeout != newCfg.lookupServiceTimeout {
timeout = newCfg.lookupServiceTimeout
}
if newCfg.lookupService == oldCfg.lookupService {
// This is the case where only the timeout has changed. We will continue
// to use the existing clientConn. but will create a new rlsClient with
// the new timeout.
lb.rlsC = newRLSClientFunc(lb.rlsCC, lb.opts.Target.Endpoint, timeout)
return
}
// This is the case where the RLS server name has changed. We need to create
// a new clientConn and close the old one.
var dopts []grpc.DialOption
if dialer := lb.opts.Dialer; dialer != nil {
dopts = append(dopts, grpc.WithContextDialer(dialer))
}
dopts = append(dopts, dialCreds(lb.opts))
cc, err := grpc.Dial(newCfg.lookupService, dopts...)
if err != nil {
logger.Errorf("rls: dialRLS(%s, %v): %v", newCfg.lookupService, lb.opts, err)
// An error from a non-blocking dial indicates something serious. We
// should continue to use the old control channel if one exists, and
// return so that the rest of the config updates can be processes.
return
}
if lb.rlsCC != nil {
lb.rlsCC.Close()
}
lb.rlsCC = cc
lb.rlsC = newRLSClientFunc(cc, lb.opts.Target.Endpoint, timeout)
}
func dialCreds(opts balancer.BuildOptions) grpc.DialOption {
// The control channel should use the same authority as that of the parent
// channel. This ensures that the identify of the RLS server and that of the
// backend is the same, so if the RLS config is injected by an attacker, it
// cannot cause leakage of private information contained in headers set by
// the application.
server := opts.Target.Authority
switch {
case opts.DialCreds != nil:
if err := opts.DialCreds.OverrideServerName(server); err != nil {
logger.Warningf("rls: OverrideServerName(%s) = (%v), using Insecure", server, err)
return grpc.WithInsecure()
}
return grpc.WithTransportCredentials(opts.DialCreds)
case opts.CredsBundle != nil:
return grpc.WithTransportCredentials(opts.CredsBundle.TransportCredentials())
default:
logger.Warning("rls: no credentials available, using Insecure")
return grpc.WithInsecure()
}
logger.Fatal("rls: ExitIdle is not yet implemented")
}

View File

@ -1,238 +0,0 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"context"
"net"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/testdata"
)
const defaultTestTimeout = 1 * time.Second
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
type listenerWrapper struct {
net.Listener
connCh *testutils.Channel
}
// Accept waits for and returns the next connection to the listener.
func (l *listenerWrapper) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.connCh.Send(c)
return c, nil
}
func setupwithListener(t *testing.T, opts ...grpc.ServerOption) (*fakeserver.Server, *listenerWrapper, func()) {
t.Helper()
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen(tcp, localhost:0): %v", err)
}
lw := &listenerWrapper{
Listener: l,
connCh: testutils.NewChannel(),
}
server, cleanup, err := fakeserver.Start(lw, opts...)
if err != nil {
t.Fatalf("fakeserver.Start(): %v", err)
}
t.Logf("Fake RLS server started at %s ...", server.Address)
return server, lw, cleanup
}
type testBalancerCC struct {
balancer.ClientConn
}
// TestUpdateControlChannelFirstConfig tests the scenario where the LB policy
// receives its first service config and verifies that a control channel to the
// RLS server specified in the serviceConfig is established.
func (s) TestUpdateControlChannelFirstConfig(t *testing.T) {
server, lis, cleanup := setupwithListener(t)
defer cleanup()
bb := balancer.Get(rlsBalancerName)
if bb == nil {
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
}
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{})
defer rlsB.Close()
t.Log("Built RLS LB policy ...")
lbCfg := &lbConfig{lookupService: server.Address}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := lis.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
// TODO: Verify channel connectivity state once control channel connectivity
// state monitoring is in place.
// TODO: Verify RLS RPC can be made once we integrate with the picker.
}
// TestUpdateControlChannelSwitch tests the scenario where a control channel
// exists and the LB policy receives a new serviceConfig with a different RLS
// server name. Verifies that the new control channel is created and the old one
// is closed (the leakchecker takes care of this).
func (s) TestUpdateControlChannelSwitch(t *testing.T) {
server1, lis1, cleanup1 := setupwithListener(t)
defer cleanup1()
server2, lis2, cleanup2 := setupwithListener(t)
defer cleanup2()
bb := balancer.Get(rlsBalancerName)
if bb == nil {
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
}
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{})
defer rlsB.Close()
t.Log("Built RLS LB policy ...")
lbCfg := &lbConfig{lookupService: server1.Address}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := lis1.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
lbCfg = &lbConfig{lookupService: server2.Address}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis2.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
// TODO: Verify channel connectivity state once control channel connectivity
// state monitoring is in place.
// TODO: Verify RLS RPC can be made once we integrate with the picker.
}
// TestUpdateControlChannelTimeout tests the scenario where the LB policy
// receives a service config update with a different lookupServiceTimeout, but
// the lookupService itself remains unchanged. It verifies that the LB policy
// does not create a new control channel in this case.
func (s) TestUpdateControlChannelTimeout(t *testing.T) {
server, lis, cleanup := setupwithListener(t)
defer cleanup()
bb := balancer.Get(rlsBalancerName)
if bb == nil {
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
}
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{})
defer rlsB.Close()
t.Log("Built RLS LB policy ...")
lbCfg := &lbConfig{lookupService: server.Address, lookupServiceTimeout: 1 * time.Second}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := lis.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
lbCfg = &lbConfig{lookupService: server.Address, lookupServiceTimeout: 2 * time.Second}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(ctx); err != context.DeadlineExceeded {
t.Fatal("LB policy created new control channel when only lookupServiceTimeout changed")
}
// TODO: Verify channel connectivity state once control channel connectivity
// state monitoring is in place.
// TODO: Verify RLS RPC can be made once we integrate with the picker.
}
// TestUpdateControlChannelWithCreds tests the scenario where the control
// channel is to established with credentials from the parent channel.
func (s) TestUpdateControlChannelWithCreds(t *testing.T) {
sCreds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
t.Fatalf("credentials.NewServerTLSFromFile(server1.pem, server1.key) = %v", err)
}
cCreds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "")
if err != nil {
t.Fatalf("credentials.NewClientTLSFromFile(ca.pem) = %v", err)
}
server, lis, cleanup := setupwithListener(t, grpc.Creds(sCreds))
defer cleanup()
bb := balancer.Get(rlsBalancerName)
if bb == nil {
t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName)
}
rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{
DialCreds: cCreds,
})
defer rlsB.Close()
t.Log("Built RLS LB policy ...")
lbCfg := &lbConfig{lookupService: server.Address}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := lis.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel")
}
// TODO: Verify channel connectivity state once control channel connectivity
// state monitoring is in place.
// TODO: Verify RLS RPC can be made once we integrate with the picker.
}

View File

@ -21,10 +21,9 @@ package rls
import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal/grpcsync"
)
const rlsBalancerName = "rls"
const rlsBalancerName = "rls_experimental"
func init() {
balancer.Register(&rlsBB{})
@ -41,13 +40,6 @@ func (*rlsBB) Name() string {
}
func (*rlsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
lb := &rlsBalancer{
done: grpcsync.NewEvent(),
cc: cc,
opts: opts,
lbCfg: &lbConfig{},
ccUpdateCh: make(chan *balancer.ClientConnState),
}
go lb.run()
return lb
// TODO(easwars): Fix this once the LB policy implementation is pulled in.
return &rlsBalancer{}
}

View File

@ -1,80 +0,0 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"context"
"time"
"google.golang.org/grpc"
rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
)
// For gRPC services using RLS, the value of target_type in the
// RouteLookupServiceRequest will be set to this.
const grpcTargetType = "grpc"
// rlsClient is a simple wrapper around a RouteLookupService client which
// provides non-blocking semantics on top of a blocking unary RPC call.
//
// The RLS LB policy creates a new rlsClient object with the following values:
// * a grpc.ClientConn to the RLS server using appropriate credentials from the
// parent channel
// * dialTarget corresponding to the original user dial target, e.g.
// "firestore.googleapis.com".
//
// The RLS LB policy uses an adaptive throttler to perform client side
// throttling and asks this client to make an RPC call only after checking with
// the throttler.
type rlsClient struct {
stub rlsgrpc.RouteLookupServiceClient
// origDialTarget is the original dial target of the user and sent in each
// RouteLookup RPC made to the RLS server.
origDialTarget string
// rpcTimeout specifies the timeout for the RouteLookup RPC call. The LB
// policy receives this value in its service config.
rpcTimeout time.Duration
}
func newRLSClient(cc *grpc.ClientConn, dialTarget string, rpcTimeout time.Duration) *rlsClient {
return &rlsClient{
stub: rlsgrpc.NewRouteLookupServiceClient(cc),
origDialTarget: dialTarget,
rpcTimeout: rpcTimeout,
}
}
type lookupCallback func(targets []string, headerData string, err error)
// lookup starts a RouteLookup RPC in a separate goroutine and returns the
// results (and error, if any) in the provided callback.
func (c *rlsClient) lookup(keyMap map[string]string, cb lookupCallback) {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), c.rpcTimeout)
resp, err := c.stub.RouteLookup(ctx, &rlspb.RouteLookupRequest{
// TODO(easwars): Use extra_keys field to populate host, service and
// method keys.
TargetType: grpcTargetType,
KeyMap: keyMap,
})
cb(resp.GetTargets(), resp.GetHeaderData(), err)
cancel()
}()
}

View File

@ -1,178 +0,0 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver"
"google.golang.org/grpc/codes"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/status"
)
const (
defaultDialTarget = "dummy"
defaultRPCTimeout = 5 * time.Second
)
func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) {
t.Helper()
server, sCleanup, err := fakeserver.Start(nil)
if err != nil {
t.Fatalf("Failed to start fake RLS server: %v", err)
}
cc, cCleanup, err := server.ClientConn()
if err != nil {
sCleanup()
t.Fatalf("Failed to get a ClientConn to the RLS server: %v", err)
}
return server, cc, func() {
sCleanup()
cCleanup()
}
}
// TestLookupFailure verifies the case where the RLS server returns an error.
func (s) TestLookupFailure(t *testing.T) {
server, cc, cleanup := setup(t)
defer cleanup()
// We setup the fake server to return an error.
server.ResponseChan <- fakeserver.Response{Err: errors.New("rls failure")}
rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout)
errCh := testutils.NewChannel()
rlsClient.lookup(nil, func(targets []string, headerData string, err error) {
if err == nil {
errCh.Send(errors.New("rlsClient.lookup() succeeded, should have failed"))
return
}
if len(targets) != 0 || headerData != "" {
errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData))
return
}
errCh.Send(nil)
})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if e, err := errCh.Receive(ctx); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
}
}
// TestLookupDeadlineExceeded tests the case where the RPC deadline associated
// with the lookup expires.
func (s) TestLookupDeadlineExceeded(t *testing.T) {
_, cc, cleanup := setup(t)
defer cleanup()
// Give the Lookup RPC a small deadline, but don't setup the fake server to
// return anything. So the Lookup call will block and eventually expire.
rlsClient := newRLSClient(cc, defaultDialTarget, 100*time.Millisecond)
errCh := testutils.NewChannel()
rlsClient.lookup(nil, func(_ []string, _ string, err error) {
if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
errCh.Send(fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded))
return
}
errCh.Send(nil)
})
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if e, err := errCh.Receive(ctx); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
}
}
// TestLookupSuccess verifies the successful Lookup API case.
func (s) TestLookupSuccess(t *testing.T) {
server, cc, cleanup := setup(t)
defer cleanup()
const wantHeaderData = "headerData"
rlsReqKeyMap := map[string]string{
"k1": "v1",
"k2": "v2",
}
wantLookupRequest := &rlspb.RouteLookupRequest{
// TODO(easwars): Use extra_keys field to populate host, service and
// method keys.
TargetType: "grpc",
KeyMap: rlsReqKeyMap,
}
wantRespTargets := []string{"us_east_1.firestore.googleapis.com"}
rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout)
errCh := testutils.NewChannel()
rlsClient.lookup(rlsReqKeyMap, func(targets []string, hd string, err error) {
if err != nil {
errCh.Send(fmt.Errorf("rlsClient.Lookup() failed: %v", err))
return
}
if !cmp.Equal(targets, wantRespTargets) || hd != wantHeaderData {
errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData))
return
}
errCh.Send(nil)
})
// Make sure that the fake server received the expected RouteLookupRequest
// proto.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
req, err := server.RequestChan.Receive(ctx)
if err != nil {
t.Fatalf("Timed out wile waiting for a RouteLookupRequest")
}
gotLookupRequest := req.(*rlspb.RouteLookupRequest)
if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" {
t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff)
}
// We setup the fake server to return this response when it receives a
// request.
server.ResponseChan <- fakeserver.Response{
Resp: &rlspb.RouteLookupResponse{
Targets: wantRespTargets,
HeaderData: wantHeaderData,
},
}
if e, err := errCh.Receive(ctx); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
}
}

View File

@ -61,7 +61,7 @@ func testEqual(a, b *lbConfig) bool {
childPolicyConfigEqual(a.childPolicyConfig, b.childPolicyConfig)
}
func TestParseConfig(t *testing.T) {
func (s) TestParseConfig(t *testing.T) {
childPolicyTargetFieldVal, _ := json.Marshal(dummyChildPolicyTarget)
tests := []struct {
desc string
@ -158,7 +158,7 @@ func TestParseConfig(t *testing.T) {
}
}
func TestParseConfigErrors(t *testing.T) {
func (s) TestParseConfigErrors(t *testing.T) {
tests := []struct {
desc string
input []byte

View File

@ -0,0 +1,206 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"context"
"fmt"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/rls/internal/adaptive"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
internalgrpclog "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/pretty"
rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
)
var newAdaptiveThrottler = func() adaptiveThrottler { return adaptive.New() }
type adaptiveThrottler interface {
ShouldThrottle() bool
RegisterBackendResponse(throttled bool)
}
// controlChannel is a wrapper around the gRPC channel to the RLS server
// specified in the service config.
type controlChannel struct {
// rpcTimeout specifies the timeout for the RouteLookup RPC call. The LB
// policy receives this value in its service config.
rpcTimeout time.Duration
// backToReadyCh is the channel on which an update is pushed when the
// connectivity state changes from READY --> TRANSIENT_FAILURE --> READY.
backToReadyCh chan struct{}
// throttler in an adaptive throttling implementation used to avoid
// hammering the RLS service while it is overloaded or down.
throttler adaptiveThrottler
cc *grpc.ClientConn
client rlsgrpc.RouteLookupServiceClient
logger *internalgrpclog.PrefixLogger
}
func newControlChannel(rlsServerName string, rpcTimeout time.Duration, bOpts balancer.BuildOptions, backToReadyCh chan struct{}) (*controlChannel, error) {
ctrlCh := &controlChannel{
rpcTimeout: rpcTimeout,
backToReadyCh: backToReadyCh,
throttler: newAdaptiveThrottler(),
}
ctrlCh.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-control-channel %p] ", ctrlCh))
dopts, err := ctrlCh.dialOpts(bOpts)
if err != nil {
return nil, err
}
ctrlCh.cc, err = grpc.Dial(rlsServerName, dopts...)
if err != nil {
return nil, err
}
ctrlCh.client = rlsgrpc.NewRouteLookupServiceClient(ctrlCh.cc)
ctrlCh.logger.Infof("Control channel created to RLS server at: %v", rlsServerName)
go ctrlCh.monitorConnectivityState()
return ctrlCh, nil
}
// dialOpts constructs the dial options for the control plane channel.
func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions) ([]grpc.DialOption, error) {
// The control plane channel will use the same authority as the parent
// channel for server authorization. This ensures that the identity of the
// RLS server and the identity of the backends is the same, so if the RLS
// config is injected by an attacker, it cannot cause leakage of private
// information contained in headers set by the application.
dopts := []grpc.DialOption{grpc.WithAuthority(bOpts.Authority)}
if bOpts.Dialer != nil {
dopts = append(dopts, grpc.WithContextDialer(bOpts.Dialer))
}
// The control channel will use the channel credentials from the parent
// channel, including any call creds associated with the channel creds.
var credsOpt grpc.DialOption
switch {
case bOpts.DialCreds != nil:
credsOpt = grpc.WithTransportCredentials(bOpts.DialCreds.Clone())
case bOpts.CredsBundle != nil:
// The "fallback" mode in google default credentials (which is the only
// type of credentials we expect to be used with RLS) uses TLS/ALTS
// creds for transport and uses the same call creds as that on the
// parent bundle.
bundle, err := bOpts.CredsBundle.NewWithMode(internal.CredsBundleModeFallback)
if err != nil {
return nil, err
}
credsOpt = grpc.WithCredentialsBundle(bundle)
default:
cc.logger.Warningf("no credentials available, using Insecure")
credsOpt = grpc.WithInsecure()
}
return append(dopts, credsOpt), nil
}
func (cc *controlChannel) monitorConnectivityState() {
cc.logger.Infof("Starting connectivity state monitoring goroutine")
// Since we use two mechanisms to deal with RLS server being down:
// - adaptive throttling for the channel as a whole
// - exponential backoff on a per-request basis
// we need a way to avoid double-penalizing requests by counting failures
// toward both mechanisms when the RLS server is unreachable.
//
// To accomplish this, we monitor the state of the control plane channel. If
// the state has been TRANSIENT_FAILURE since the last time it was in state
// READY, and it then transitions into state READY, we push on a channel
// which is being read by the LB policy.
//
// The LB the policy will iterate through the cache to reset the backoff
// timeouts in all cache entries. Specifically, this means that it will
// reset the backoff state and cancel the pending backoff timer. Note that
// when cancelling the backoff timer, just like when the backoff timer fires
// normally, a new picker is returned to the channel, to force it to
// re-process any wait-for-ready RPCs that may still be queued if we failed
// them while we were in backoff. However, we should optimize this case by
// returning only one new picker, regardless of how many backoff timers are
// cancelled.
// Using the background context is fine here since we check for the ClientConn
// entering SHUTDOWN and return early in that case.
ctx := context.Background()
first := true
for {
// Wait for the control channel to become READY.
for s := cc.cc.GetState(); s != connectivity.Ready; s = cc.cc.GetState() {
if s == connectivity.Shutdown {
return
}
cc.cc.WaitForStateChange(ctx, s)
}
cc.logger.Infof("Connectivity state is READY")
if !first {
cc.logger.Infof("Control channel back to READY")
cc.backToReadyCh <- struct{}{}
}
first = false
// Wait for the control channel to move out of READY.
cc.cc.WaitForStateChange(ctx, connectivity.Ready)
if cc.cc.GetState() == connectivity.Shutdown {
return
}
cc.logger.Infof("Connectivity state is %s", cc.cc.GetState())
}
}
func (cc *controlChannel) close() {
cc.logger.Infof("Closing control channel")
cc.cc.Close()
}
type lookupCallback func(targets []string, headerData string, err error)
// lookup starts a RouteLookup RPC in a separate goroutine and returns the
// results (and error, if any) in the provided callback.
//
// The returned boolean indicates whether the request was throttled by the
// client-side adaptive throttling algorithm in which case the provided callback
// will not be invoked.
func (cc *controlChannel) lookup(reqKeys map[string]string, reason rlspb.RouteLookupRequest_Reason, staleHeaders string, cb lookupCallback) (throttled bool) {
if cc.throttler.ShouldThrottle() {
cc.logger.Infof("RLS request throttled by client-side adaptive throttling")
return true
}
go func() {
req := &rlspb.RouteLookupRequest{
TargetType: "grpc",
KeyMap: reqKeys,
Reason: reason,
StaleHeaderData: staleHeaders,
}
cc.logger.Infof("Sending RLS request %+v", pretty.ToJSON(req))
ctx, cancel := context.WithTimeout(context.Background(), cc.rpcTimeout)
defer cancel()
resp, err := cc.client.RouteLookup(ctx, req)
cb(resp.GetTargets(), resp.GetHeaderData(), err)
}()
return false
}

View File

@ -0,0 +1,469 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/rls/internal/test/e2e"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/testdata"
"google.golang.org/protobuf/proto"
)
// TestControlChannelThrottled tests the case where the adaptive throttler
// indicates that the control channel needs to be throttled.
func (s) TestControlChannelThrottled(t *testing.T) {
// Start an RLS server and set the throttler to always throttle requests.
rlsServer, rlsReqCh := setupFakeRLSServer(t, nil)
overrideAdaptiveThrottler(t, alwaysThrottlingThrottler())
// Create a control channel to the fake RLS server.
ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{}, nil)
if err != nil {
t.Fatalf("Failed to create control channel to RLS server: %v", err)
}
defer ctrlCh.close()
// Perform the lookup and expect the attempt to be throttled.
ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, nil)
select {
case <-rlsReqCh:
t.Fatal("RouteLookup RPC invoked when control channel is throtlled")
case <-time.After(defaultTestShortTimeout):
}
}
// TestLookupFailure tests the case where the RLS server responds with an error.
func (s) TestLookupFailure(t *testing.T) {
// Start an RLS server and set the throttler to never throttle requests.
rlsServer, _ := setupFakeRLSServer(t, nil)
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Setup the RLS server to respond with errors.
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse {
return &e2e.RouteLookupResponse{Err: errors.New("rls failure")}
})
// Create a control channel to the fake RLS server.
ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{}, nil)
if err != nil {
t.Fatalf("Failed to create control channel to RLS server: %v", err)
}
defer ctrlCh.close()
// Perform the lookup and expect the callback to be invoked with an error.
errCh := make(chan error, 1)
ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
if err == nil {
errCh <- errors.New("rlsClient.lookup() succeeded, should have failed")
return
}
errCh <- nil
})
select {
case <-time.After(defaultTestTimeout):
t.Fatal("timeout when waiting for lookup callback to be invoked")
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}
// TestLookupDeadlineExceeded tests the case where the RLS server does not
// respond within the configured rpc timeout.
func (s) TestLookupDeadlineExceeded(t *testing.T) {
// A unary interceptor which sleeps for long enough to cause lookup RPCs to
// exceed their deadline.
rlsReqCh := make(chan struct{}, 1)
interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
rlsReqCh <- struct{}{}
time.Sleep(2 * defaultTestShortTimeout)
return handler(ctx, req)
}
// Start an RLS server and set the throttler to never throttle.
rlsServer, _ := setupFakeRLSServer(t, nil, grpc.UnaryInterceptor(interceptor))
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Create a control channel with a small deadline.
ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestShortTimeout, balancer.BuildOptions{}, nil)
if err != nil {
t.Fatalf("Failed to create control channel to RLS server: %v", err)
}
defer ctrlCh.close()
// Perform the lookup and expect the callback to be invoked with an error.
errCh := make(chan error)
ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)
return
}
errCh <- nil
})
select {
case <-time.After(defaultTestTimeout):
t.Fatal("timeout when waiting for lookup callback to be invoked")
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}
// testCredsBundle wraps a test call creds and real transport creds.
type testCredsBundle struct {
transportCreds credentials.TransportCredentials
callCreds credentials.PerRPCCredentials
}
func (f *testCredsBundle) TransportCredentials() credentials.TransportCredentials {
return f.transportCreds
}
func (f *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
return f.callCreds
}
func (f *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
if mode != internal.CredsBundleModeFallback {
return nil, fmt.Errorf("unsupported mode: %v", mode)
}
return &testCredsBundle{
transportCreds: f.transportCreds,
callCreds: f.callCreds,
}, nil
}
var (
// Call creds sent by the testPerRPCCredentials on the client, and verified
// by an interceptor on the server.
perRPCCredsData = map[string]string{
"test-key": "test-value",
"test-key-bin": string([]byte{1, 2, 3}),
}
)
type testPerRPCCredentials struct {
callCreds map[string]string
}
func (f *testPerRPCCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
return f.callCreds, nil
}
func (f *testPerRPCCredentials) RequireTransportSecurity() bool {
return true
}
// Unary server interceptor which validates if the RPC contains call credentials
// which match `perRPCCredsData
func callCredsValidatingServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context")
}
for k, want := range perRPCCredsData {
got, ok := md[k]
if !ok {
return ctx, status.Errorf(codes.PermissionDenied, "didn't find call creds key %v in context", k)
}
if got[0] != want {
return ctx, status.Errorf(codes.PermissionDenied, "for key %v, got value %v, want %v", k, got, want)
}
}
return handler(ctx, req)
}
// makeTLSCreds is a test helper which creates a TLS based transport credentials
// from files specified in the arguments.
func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials {
cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
if err != nil {
t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err)
}
b, err := ioutil.ReadFile(testdata.Path(rootsPath))
if err != nil {
t.Fatalf("ioutil.ReadFile(%q) failed: %v", rootsPath, err)
}
roots := x509.NewCertPool()
if !roots.AppendCertsFromPEM(b) {
t.Fatal("failed to append certificates")
}
return credentials.NewTLS(&tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: roots,
})
}
const (
wantHeaderData = "headerData"
staleHeaderData = "staleHeaderData"
)
var (
keyMap = map[string]string{
"k1": "v1",
"k2": "v2",
}
wantTargets = []string{"us_east_1.firestore.googleapis.com"}
lookupRequest = &rlspb.RouteLookupRequest{
TargetType: "grpc",
KeyMap: keyMap,
Reason: rlspb.RouteLookupRequest_REASON_MISS,
StaleHeaderData: staleHeaderData,
}
lookupResponse = &e2e.RouteLookupResponse{
Resp: &rlspb.RouteLookupResponse{
Targets: wantTargets,
HeaderData: wantHeaderData,
},
}
)
func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions) {
// Start an RLS server and set the throttler to never throttle requests.
rlsServer, _ := setupFakeRLSServer(t, nil, sopts...)
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Setup the RLS server to respond with a valid response.
rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse {
return lookupResponse
})
// Verify that the request received by the RLS matches the expected one.
rlsServer.SetRequestCallback(func(got *rlspb.RouteLookupRequest) {
if diff := cmp.Diff(lookupRequest, got, cmp.Comparer(proto.Equal)); diff != "" {
t.Errorf("RouteLookupRequest diff (-want, +got):\n%s", diff)
}
})
// Create a control channel to the fake server.
ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, bopts, nil)
if err != nil {
t.Fatalf("Failed to create control channel to RLS server: %v", err)
}
defer ctrlCh.close()
// Perform the lookup and expect a successful callback invocation.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
errCh := make(chan error, 1)
ctrlCh.lookup(keyMap, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(targets []string, headerData string, err error) {
if err != nil {
errCh <- fmt.Errorf("rlsClient.lookup() failed with err: %v", err)
return
}
if !cmp.Equal(targets, wantTargets) || headerData != wantHeaderData {
errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, headerData, wantTargets, wantHeaderData)
return
}
errCh <- nil
})
select {
case <-ctx.Done():
t.Fatal("timeout when waiting for lookup callback to be invoked")
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}
// TestControlChannelCredsSuccess tests creation of the control channel with
// different credentials, which are expected to succeed.
func (s) TestControlChannelCredsSuccess(t *testing.T) {
serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
tests := []struct {
name string
sopts []grpc.ServerOption
bopts balancer.BuildOptions
}{
{
name: "insecure",
sopts: nil,
bopts: balancer.BuildOptions{},
},
{
name: "transport creds only",
sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
bopts: balancer.BuildOptions{
DialCreds: clientCreds,
Authority: "x.test.example.com",
},
},
{
name: "creds bundle",
sopts: []grpc.ServerOption{
grpc.Creds(serverCreds),
grpc.UnaryInterceptor(callCredsValidatingServerInterceptor),
},
bopts: balancer.BuildOptions{
CredsBundle: &testCredsBundle{
transportCreds: clientCreds,
callCreds: &testPerRPCCredentials{callCreds: perRPCCredsData},
},
Authority: "x.test.example.com",
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
testControlChannelCredsSuccess(t, test.sopts, test.bopts)
})
}
}
func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErr string) {
// StartFakeRouteLookupServer a fake server.
//
// Start an RLS server and set the throttler to never throttle requests. The
// creds failures happen before the RPC handler on the server is invoked.
// So, there is need to setup the request and responses on the fake server.
rlsServer, _ := setupFakeRLSServer(t, nil, sopts...)
overrideAdaptiveThrottler(t, neverThrottlingThrottler())
// Create the control channel to the fake server.
ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, bopts, nil)
if err != nil {
t.Fatalf("Failed to create control channel to RLS server: %v", err)
}
defer ctrlCh.close()
// Perform the lookup and expect the callback to be invoked with an error.
errCh := make(chan error)
ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !strings.Contains(st.String(), wantErr) {
errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErr)
return
}
errCh <- nil
})
select {
case <-time.After(defaultTestTimeout):
t.Fatal("timeout when waiting for lookup callback to be invoked")
case err := <-errCh:
if err != nil {
t.Fatal(err)
}
}
}
// TestControlChannelCredsFailure tests creation of the control channel with
// different credentials, which are expected to fail.
func (s) TestControlChannelCredsFailure(t *testing.T) {
serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
tests := []struct {
name string
sopts []grpc.ServerOption
bopts balancer.BuildOptions
wantCode codes.Code
wantErr string
}{
{
name: "transport creds authority mismatch",
sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
bopts: balancer.BuildOptions{
DialCreds: clientCreds,
Authority: "authority-mismatch",
},
wantCode: codes.Unavailable,
wantErr: "transport: authentication handshake failed: x509: certificate is valid for *.test.example.com, not authority-mismatch",
},
{
name: "transport creds handshake failure",
sopts: nil, // server expects insecure connection
bopts: balancer.BuildOptions{
DialCreds: clientCreds,
Authority: "x.test.example.com",
},
wantCode: codes.Unavailable,
wantErr: "transport: authentication handshake failed: tls: first record does not look like a TLS handshake",
},
{
name: "call creds mismatch",
sopts: []grpc.ServerOption{
grpc.Creds(serverCreds),
grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), // server expects call creds
},
bopts: balancer.BuildOptions{
CredsBundle: &testCredsBundle{
transportCreds: clientCreds,
callCreds: &testPerRPCCredentials{}, // sends no call creds
},
Authority: "x.test.example.com",
},
wantCode: codes.PermissionDenied,
wantErr: "didn't find call creds",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErr)
})
}
}
type unsupportedCredsBundle struct {
credentials.Bundle
}
func (*unsupportedCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
return nil, fmt.Errorf("unsupported mode: %v", mode)
}
// TestNewControlChannelUnsupportedCredsBundle tests the case where the control
// channel is configured with a bundle which does not support the mode we use.
func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) {
rlsServer, _ := setupFakeRLSServer(t, nil)
// Create the control channel to the fake server.
ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{CredsBundle: &unsupportedCredsBundle{}}, nil)
if err == nil {
ctrlCh.close()
t.Fatal("newControlChannel succeeded when expected to fail")
}
}

View File

@ -0,0 +1,327 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"context"
"net"
"strings"
"sync"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer/rls/internal/test/e2e"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancergroup"
"google.golang.org/grpc/internal/grpctest"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/status"
testgrpc "google.golang.org/grpc/test/grpc_testing"
testpb "google.golang.org/grpc/test/grpc_testing"
"google.golang.org/protobuf/types/known/durationpb"
)
// TODO(easwars): Remove this once all RLS code is merged.
//lint:file-ignore U1000 Ignore all unused code, not all code is merged yet.
const (
defaultTestTimeout = 5 * time.Second
defaultTestShortTimeout = 100 * time.Millisecond
)
func init() {
balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond
}
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
// connWrapper wraps a net.Conn and pushes on a channel when closed.
type connWrapper struct {
net.Conn
closeCh *testutils.Channel
}
func (cw *connWrapper) Close() error {
err := cw.Conn.Close()
cw.closeCh.Replace(nil)
return err
}
// listenerWrapper wraps a net.Listener and the returned net.Conn.
//
// It pushes on a channel whenever it accepts a new connection.
type listenerWrapper struct {
net.Listener
newConnCh *testutils.Channel
}
func (l *listenerWrapper) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
closeCh := testutils.NewChannel()
conn := &connWrapper{Conn: c, closeCh: closeCh}
l.newConnCh.Send(conn)
return conn, nil
}
func newListenerWrapper(t *testing.T, lis net.Listener) *listenerWrapper {
if lis == nil {
var err error
lis, err = testutils.LocalTCPListener()
if err != nil {
t.Fatal(err)
}
}
return &listenerWrapper{
Listener: lis,
newConnCh: testutils.NewChannel(),
}
}
// fakeBackoffStrategy is a fake implementation of the backoff.Strategy
// interface, for tests to inject the backoff duration.
type fakeBackoffStrategy struct {
backoff time.Duration
}
func (f *fakeBackoffStrategy) Backoff(retries int) time.Duration {
return f.backoff
}
// fakeThrottler is a fake implementation of the adaptiveThrottler interface.
type fakeThrottler struct {
throttleFunc func() bool
}
func (f *fakeThrottler) ShouldThrottle() bool { return f.throttleFunc() }
func (f *fakeThrottler) RegisterBackendResponse(bool) {}
// alwaysThrottlingThrottler returns a fake throttler which always throttles.
func alwaysThrottlingThrottler() *fakeThrottler {
return &fakeThrottler{throttleFunc: func() bool { return true }}
}
// neverThrottlingThrottler returns a fake throttler which never throttles.
func neverThrottlingThrottler() *fakeThrottler {
return &fakeThrottler{throttleFunc: func() bool { return false }}
}
// oneTimeAllowingThrottler returns a fake throttler which does not throttle the
// first request, but throttles everything that comes after. This is useful for
// tests which need to set up a valid cache entry before testing other cases.
func oneTimeAllowingThrottler() *fakeThrottler {
var once sync.Once
return &fakeThrottler{
throttleFunc: func() bool {
throttle := true
once.Do(func() { throttle = false })
return throttle
},
}
}
func overrideAdaptiveThrottler(t *testing.T, f *fakeThrottler) {
origAdaptiveThrottler := newAdaptiveThrottler
newAdaptiveThrottler = func() adaptiveThrottler { return f }
t.Cleanup(func() { newAdaptiveThrottler = origAdaptiveThrottler })
}
// setupFakeRLSServer starts and returns a fake RouteLookupService server
// listening on the given listener or on a random local port. Also returns a
// channel for tests to get notified whenever the RouteLookup RPC is invoked on
// the fake server.
//
// This function sets up the fake server to respond with an empty response for
// the RouteLookup RPCs. Tests can override this by calling the
// SetResponseCallback() method on the returned fake server.
func setupFakeRLSServer(t *testing.T, lis net.Listener, opts ...grpc.ServerOption) (*e2e.FakeRouteLookupServer, chan struct{}) {
s, cancel := e2e.StartFakeRouteLookupServer(t, lis, opts...)
t.Logf("Started fake RLS server at %q", s.Address)
ch := make(chan struct{}, 1)
s.SetRequestCallback(func(request *rlspb.RouteLookupRequest) {
select {
case ch <- struct{}{}:
default:
}
})
t.Cleanup(cancel)
return s, ch
}
// buildBasicRLSConfig constructs a basic service config for the RLS LB policy
// which header matching rules. This expects the passed child policy name to
// have been registered by the caller.
func buildBasicRLSConfig(childPolicyName, rlsServerAddress string) *e2e.RLSConfig {
return &e2e.RLSConfig{
RouteLookupConfig: &rlspb.RouteLookupConfig{
GrpcKeybuilders: []*rlspb.GrpcKeyBuilder{
{
Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "grpc.testing.TestService"}},
Headers: []*rlspb.NameMatcher{
{Key: "k1", Names: []string{"n1"}},
{Key: "k2", Names: []string{"n2"}},
},
},
},
LookupService: rlsServerAddress,
LookupServiceTimeout: durationpb.New(defaultTestTimeout),
CacheSizeBytes: 1024,
},
ChildPolicy: &internalserviceconfig.BalancerConfig{Name: childPolicyName},
ChildPolicyConfigTargetFieldName: e2e.RLSChildPolicyTargetNameField,
}
}
// buildBasicRLSConfigWithChildPolicy constructs a very basic service config for
// the RLS LB policy. It also registers a test LB policy which is capable of
// being a child of the RLS LB policy.
func buildBasicRLSConfigWithChildPolicy(t *testing.T, childPolicyName, rlsServerAddress string) *e2e.RLSConfig {
childPolicyName = "test-child-policy" + childPolicyName
e2e.RegisterRLSChildPolicy(childPolicyName, nil)
t.Logf("Registered child policy with name %q", childPolicyName)
return &e2e.RLSConfig{
RouteLookupConfig: &rlspb.RouteLookupConfig{
GrpcKeybuilders: []*rlspb.GrpcKeyBuilder{{Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "grpc.testing.TestService"}}}},
LookupService: rlsServerAddress,
LookupServiceTimeout: durationpb.New(defaultTestTimeout),
CacheSizeBytes: 1024,
},
ChildPolicy: &internalserviceconfig.BalancerConfig{Name: childPolicyName},
ChildPolicyConfigTargetFieldName: e2e.RLSChildPolicyTargetNameField,
}
}
// startBackend starts a backend implementing the TestService on a local port.
// It returns a channel for tests to get notified whenever an RPC is invoked on
// the backend. This allows tests to ensure that RPCs reach expected backends.
// Also returns the address of the backend.
func startBackend(t *testing.T, sopts ...grpc.ServerOption) (rpcCh chan struct{}, address string) {
t.Helper()
rpcCh = make(chan struct{}, 1)
backend := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
select {
case rpcCh <- struct{}{}:
default:
}
return &testpb.Empty{}, nil
},
}
if err := backend.StartServer(sopts...); err != nil {
t.Fatalf("Failed to start backend: %v", err)
}
t.Logf("Started TestService backend at: %q", backend.Address)
t.Cleanup(func() { backend.Stop() })
return rpcCh, backend.Address
}
// startManualResolverWithConfig registers and returns a manual resolver which
// pushes the RLS LB policy's service config on the channel.
func startManualResolverWithConfig(t *testing.T, rlsConfig *e2e.RLSConfig) *manual.Resolver {
t.Helper()
scJSON, err := rlsConfig.ServiceConfigJSON()
if err != nil {
t.Fatal(err)
}
sc := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(scJSON)
r := manual.NewBuilderWithScheme("rls-e2e")
r.InitialState(resolver.State{ServiceConfig: sc})
t.Cleanup(r.Close)
return r
}
// makeTestRPCAndExpectItToReachBackend is a test helper function which makes
// the EmptyCall RPC on the given ClientConn and verifies that it reaches a
// backend. The latter is accomplished by listening on the provided channel
// which gets pushed to whenever the backend in question gets an RPC.
func makeTestRPCAndExpectItToReachBackend(ctx context.Context, t *testing.T, cc *grpc.ClientConn, ch chan struct{}) {
t.Helper()
client := testgrpc.NewTestServiceClient(cc)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("TestService/EmptyCall() failed with error: %v", err)
}
select {
case <-ctx.Done():
t.Fatalf("Timeout when waiting for backend to receive RPC")
case <-ch:
}
}
// makeTestRPCAndVerifyError is a test helper function which makes the EmptyCall
// RPC on the given ClientConn and verifies that the RPC fails with the given
// status code and error.
func makeTestRPCAndVerifyError(ctx context.Context, t *testing.T, cc *grpc.ClientConn, wantCode codes.Code, wantErr error) {
t.Helper()
client := testgrpc.NewTestServiceClient(cc)
_, err := client.EmptyCall(ctx, &testpb.Empty{})
if err == nil {
t.Fatal("TestService/EmptyCall() succeeded when expected to fail")
}
if code := status.Code(err); code != wantCode {
t.Fatalf("TestService/EmptyCall() returned code: %v, want: %v", code, wantCode)
}
if wantErr != nil && !strings.Contains(err.Error(), wantErr.Error()) {
t.Fatalf("TestService/EmptyCall() returned err: %v, want: %v", err, wantErr)
}
}
// verifyRLSRequest is a test helper which listens on a channel to see if an RLS
// request was received by the fake RLS server. Based on whether the test
// expects a request to be sent out or not, it uses a different timeout.
func verifyRLSRequest(t *testing.T, ch chan struct{}, wantRequest bool) {
t.Helper()
if wantRequest {
select {
case <-time.After(defaultTestTimeout):
t.Fatalf("Timeout when waiting for an RLS request to be sent out")
case <-ch:
}
} else {
select {
case <-time.After(defaultTestShortTimeout):
case <-ch:
t.Fatalf("RLS request sent out when not expecting one")
}
}
}

View File

@ -1,147 +0,0 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"errors"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/rls/internal/cache"
"google.golang.org/grpc/balancer/rls/internal/keys"
"google.golang.org/grpc/metadata"
)
var errRLSThrottled = errors.New("RLS call throttled at client side")
// RLS rlsPicker selects the subConn to be used for a particular RPC. It does
// not manage subConns directly and usually deletegates to pickers provided by
// child policies.
//
// The RLS LB policy creates a new rlsPicker object whenever its ServiceConfig
// is updated and provides a bunch of hooks for the rlsPicker to get the latest
// state that it can used to make its decision.
type rlsPicker struct {
// The keyBuilder map used to generate RLS keys for the RPC. This is built
// by the LB policy based on the received ServiceConfig.
kbm keys.BuilderMap
// Endpoint from the user's original dial target. Used to set the `host_key`
// field in `extra_keys`.
origEndpoint string
// The following hooks are setup by the LB policy to enable the rlsPicker to
// access state stored in the policy. This approach has the following
// advantages:
// 1. The rlsPicker is loosely coupled with the LB policy in the sense that
// updates happening on the LB policy like the receipt of an RLS
// response, or an update to the default rlsPicker etc are not explicitly
// pushed to the rlsPicker, but are readily available to the rlsPicker
// when it invokes these hooks. And the LB policy takes care of
// synchronizing access to these shared state.
// 2. It makes unit testing the rlsPicker easy since any number of these
// hooks could be overridden.
// readCache is used to read from the data cache and the pending request
// map in an atomic fashion. The first return parameter is the entry in the
// data cache, and the second indicates whether an entry for the same key
// is present in the pending cache.
readCache func(cache.Key) (*cache.Entry, bool)
// shouldThrottle decides if the current RPC should be throttled at the
// client side. It uses an adaptive throttling algorithm.
shouldThrottle func() bool
// startRLS kicks off an RLS request in the background for the provided RPC
// path and keyMap. An entry in the pending request map is created before
// sending out the request and an entry in the data cache is created or
// updated upon receipt of a response. See implementation in the LB policy
// for details.
startRLS func(string, keys.KeyMap)
// defaultPick enables the rlsPicker to delegate the pick decision to the
// rlsPicker returned by the child LB policy pointing to the default target
// specified in the service config.
defaultPick func(balancer.PickInfo) (balancer.PickResult, error)
}
// Pick makes the routing decision for every outbound RPC.
func (p *rlsPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
// Build the request's keys using the key builders from LB config.
md, _ := metadata.FromOutgoingContext(info.Ctx)
km := p.kbm.RLSKey(md, p.origEndpoint, info.FullMethodName)
// We use the LB policy hook to read the data cache and the pending request
// map (whether or not an entry exists) for the RPC path and the generated
// RLS keys. We will end up kicking off an RLS request only if there is no
// pending request for the current RPC path and keys, and either we didn't
// find an entry in the data cache or the entry was stale and it wasn't in
// backoff.
startRequest := false
now := time.Now()
entry, pending := p.readCache(cache.Key{Path: info.FullMethodName, KeyMap: km.Str})
if entry == nil {
startRequest = true
} else {
entry.Mu.Lock()
defer entry.Mu.Unlock()
if entry.StaleTime.Before(now) && entry.BackoffTime.Before(now) {
// This is the proactive cache refresh.
startRequest = true
}
}
if startRequest && !pending {
if p.shouldThrottle() {
// The entry doesn't exist or has expired and the new RLS request
// has been throttled. Treat it as an error and delegate to default
// pick, if one exists, or fail the pick.
if entry == nil || entry.ExpiryTime.Before(now) {
if p.defaultPick != nil {
return p.defaultPick(info)
}
return balancer.PickResult{}, errRLSThrottled
}
// The proactive refresh has been throttled. Nothing to worry, just
// keep using the existing entry.
} else {
p.startRLS(info.FullMethodName, km)
}
}
if entry != nil {
if entry.ExpiryTime.After(now) {
// This is the jolly good case where we have found a valid entry in
// the data cache. We delegate to the LB policy associated with
// this cache entry.
return entry.ChildPicker.Pick(info)
} else if entry.BackoffTime.After(now) {
// The entry has expired, but is in backoff. We delegate to the
// default pick, if one exists, or return the error from the last
// failed RLS request for this entry.
if p.defaultPick != nil {
return p.defaultPick(info)
}
return balancer.PickResult{}, entry.CallStatus
}
}
// We get here only in the following cases:
// * No data cache entry or expired entry, RLS request sent out
// * No valid data cache entry and Pending cache entry exists
// We need to queue to pick which will be handled once the RLS response is
// received.
return balancer.PickResult{}, balancer.ErrNoSubConnAvailable
}

View File

@ -1,615 +0,0 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rls
import (
"context"
"errors"
"fmt"
"math"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/rls/internal/cache"
"google.golang.org/grpc/balancer/rls/internal/keys"
"google.golang.org/grpc/internal/grpcrand"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
)
const defaultTestMaxAge = 5 * time.Second
// initKeyBuilderMap initializes a keyBuilderMap of the form:
// {
// "gFoo": "k1=n1",
// "gBar/method1": "k2=n21,n22"
// "gFoobar": "k3=n3",
// }
func initKeyBuilderMap() (keys.BuilderMap, error) {
kb1 := &rlspb.GrpcKeyBuilder{
Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "gFoo"}},
Headers: []*rlspb.NameMatcher{{Key: "k1", Names: []string{"n1"}}},
}
kb2 := &rlspb.GrpcKeyBuilder{
Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "gBar", Method: "method1"}},
Headers: []*rlspb.NameMatcher{{Key: "k2", Names: []string{"n21", "n22"}}},
}
kb3 := &rlspb.GrpcKeyBuilder{
Names: []*rlspb.GrpcKeyBuilder_Name{{Service: "gFoobar"}},
Headers: []*rlspb.NameMatcher{{Key: "k3", Names: []string{"n3"}}},
}
return keys.MakeBuilderMap(&rlspb.RouteLookupConfig{
GrpcKeybuilders: []*rlspb.GrpcKeyBuilder{kb1, kb2, kb3},
})
}
// fakeSubConn embeds the balancer.SubConn interface and contains an id which
// helps verify that the expected subConn was returned by the rlsPicker.
type fakeSubConn struct {
balancer.SubConn
id int
}
// fakePicker sends a PickResult with a fakeSubConn with the configured id.
type fakePicker struct {
id int
}
func (p *fakePicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) {
return balancer.PickResult{SubConn: &fakeSubConn{id: p.id}}, nil
}
// newFakePicker returns a fakePicker configured with a random ID. The subConns
// returned by this picker are of type fakefakeSubConn, and contain the same
// random ID, which tests can use to verify.
func newFakePicker() *fakePicker {
return &fakePicker{id: grpcrand.Intn(math.MaxInt32)}
}
func verifySubConn(sc balancer.SubConn, wantID int) error {
fsc, ok := sc.(*fakeSubConn)
if !ok {
return fmt.Errorf("Pick() returned a SubConn of type %T, want %T", sc, &fakeSubConn{})
}
if fsc.id != wantID {
return fmt.Errorf("Pick() returned SubConn %d, want %d", fsc.id, wantID)
}
return nil
}
// TestPickKeyBuilder verifies the different possible scenarios for forming an
// RLS key for an incoming RPC.
func TestPickKeyBuilder(t *testing.T) {
kbm, err := initKeyBuilderMap()
if err != nil {
t.Fatalf("Failed to create keyBuilderMap: %v", err)
}
tests := []struct {
desc string
rpcPath string
md metadata.MD
wantKey cache.Key
}{
{
desc: "non existent service in keyBuilder map",
rpcPath: "/gNonExistentService/method",
md: metadata.New(map[string]string{"n1": "v1", "n3": "v3"}),
wantKey: cache.Key{Path: "/gNonExistentService/method", KeyMap: ""},
},
{
desc: "no metadata in incoming context",
rpcPath: "/gFoo/method",
md: metadata.MD{},
wantKey: cache.Key{Path: "/gFoo/method", KeyMap: ""},
},
{
desc: "keyBuilderMatch",
rpcPath: "/gFoo/method",
md: metadata.New(map[string]string{"n1": "v1", "n3": "v3"}),
wantKey: cache.Key{Path: "/gFoo/method", KeyMap: "k1=v1"},
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
randID := grpcrand.Intn(math.MaxInt32)
p := rlsPicker{
kbm: kbm,
readCache: func(key cache.Key) (*cache.Entry, bool) {
if !cmp.Equal(key, test.wantKey) {
t.Fatalf("rlsPicker using cacheKey %v, want %v", key, test.wantKey)
}
now := time.Now()
return &cache.Entry{
ExpiryTime: now.Add(defaultTestMaxAge),
StaleTime: now.Add(defaultTestMaxAge),
// Cache entry is configured with a child policy whose
// rlsPicker always returns an empty PickResult and nil
// error.
ChildPicker: &fakePicker{id: randID},
}, false
},
// The other hooks are not set here because they are not expected to be
// invoked for these cases and if they get invoked, they will panic.
}
gotResult, err := p.Pick(balancer.PickInfo{
FullMethodName: test.rpcPath,
Ctx: metadata.NewOutgoingContext(context.Background(), test.md),
})
if err != nil {
t.Fatalf("Pick() failed with error: %v", err)
}
sc, ok := gotResult.SubConn.(*fakeSubConn)
if !ok {
t.Fatalf("Pick() returned a SubConn of type %T, want %T", gotResult.SubConn, &fakeSubConn{})
}
if sc.id != randID {
t.Fatalf("Pick() returned SubConn %d, want %d", sc.id, randID)
}
})
}
}
// TestPick_DataCacheMiss_PendingCacheMiss verifies different Pick scenarios
// where the entry is neither found in the data cache nor in the pending cache.
func TestPick_DataCacheMiss_PendingCacheMiss(t *testing.T) {
const (
rpcPath = "/gFoo/method"
wantKeyMapStr = "k1=v1"
)
kbm, err := initKeyBuilderMap()
if err != nil {
t.Fatalf("Failed to create keyBuilderMap: %v", err)
}
md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"})
wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr}
tests := []struct {
desc string
// Whether or not a default target is configured.
defaultPickExists bool
// Whether or not the RLS request should be throttled.
throttle bool
// Whether or not the test is expected to make a new RLS request.
wantRLSRequest bool
// Expected error returned by the rlsPicker under test.
wantErr error
}{
{
desc: "rls request throttled with default pick",
defaultPickExists: true,
throttle: true,
},
{
desc: "rls request throttled without default pick",
throttle: true,
wantErr: errRLSThrottled,
},
{
desc: "rls request not throttled",
wantRLSRequest: true,
wantErr: balancer.ErrNoSubConnAvailable,
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
rlsCh := testutils.NewChannel()
defaultPicker := newFakePicker()
p := rlsPicker{
kbm: kbm,
// Cache lookup fails, no pending entry.
readCache: func(key cache.Key) (*cache.Entry, bool) {
if !cmp.Equal(key, wantKey) {
t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey)
}
return nil, false
},
shouldThrottle: func() bool { return test.throttle },
startRLS: func(path string, km keys.KeyMap) {
if !test.wantRLSRequest {
rlsCh.Send(errors.New("RLS request attempted when none was expected"))
return
}
if path != rpcPath {
rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath))
return
}
if km.Str != wantKeyMapStr {
rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr))
return
}
rlsCh.Send(nil)
},
}
if test.defaultPickExists {
p.defaultPick = defaultPicker.Pick
}
gotResult, err := p.Pick(balancer.PickInfo{
FullMethodName: rpcPath,
Ctx: metadata.NewOutgoingContext(context.Background(), md),
})
if err != test.wantErr {
t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr)
}
// If the test specified that a new RLS request should be made,
// verify it.
if test.wantRLSRequest {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if rlsErr, err := rlsCh.Receive(ctx); err != nil || rlsErr != nil {
t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err)
}
}
if test.wantErr != nil {
return
}
// We get here only for cases where we expect the pick to be
// delegated to the default picker.
if err := verifySubConn(gotResult.SubConn, defaultPicker.id); err != nil {
t.Fatal(err)
}
})
}
}
// TestPick_DataCacheMiss_PendingCacheMiss verifies different Pick scenarios
// where the entry is not found in the data cache, but there is a entry in the
// pending cache. For all of these scenarios, no new RLS request will be sent.
func TestPick_DataCacheMiss_PendingCacheHit(t *testing.T) {
const (
rpcPath = "/gFoo/method"
wantKeyMapStr = "k1=v1"
)
kbm, err := initKeyBuilderMap()
if err != nil {
t.Fatalf("Failed to create keyBuilderMap: %v", err)
}
md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"})
wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr}
tests := []struct {
desc string
defaultPickExists bool
}{
{
desc: "default pick exists",
defaultPickExists: true,
},
{
desc: "default pick does not exists",
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
rlsCh := testutils.NewChannel()
p := rlsPicker{
kbm: kbm,
// Cache lookup fails, pending entry exists.
readCache: func(key cache.Key) (*cache.Entry, bool) {
if !cmp.Equal(key, wantKey) {
t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey)
}
return nil, true
},
// Never throttle. We do not expect an RLS request to be sent out anyways.
shouldThrottle: func() bool { return false },
startRLS: func(_ string, _ keys.KeyMap) {
rlsCh.Send(nil)
},
}
if test.defaultPickExists {
p.defaultPick = func(info balancer.PickInfo) (balancer.PickResult, error) {
// We do not expect the default picker to be invoked at all.
// So, if we get here, the test will fail, because it
// expects the pick to be queued.
return balancer.PickResult{}, nil
}
}
if _, err := p.Pick(balancer.PickInfo{
FullMethodName: rpcPath,
Ctx: metadata.NewOutgoingContext(context.Background(), md),
}); err != balancer.ErrNoSubConnAvailable {
t.Fatalf("Pick() returned error {%v}, want {%v}", err, balancer.ErrNoSubConnAvailable)
}
// Make sure that no RLS request was sent out.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := rlsCh.Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("RLS request sent out when pending entry exists")
}
})
}
}
// TestPick_DataCacheHit_PendingCacheMiss verifies different Pick scenarios
// where the entry is found in the data cache, and there is no entry in the
// pending cache. This includes cases where the entry in the data cache is
// stale, expired or in backoff.
func TestPick_DataCacheHit_PendingCacheMiss(t *testing.T) {
const (
rpcPath = "/gFoo/method"
wantKeyMapStr = "k1=v1"
)
kbm, err := initKeyBuilderMap()
if err != nil {
t.Fatalf("Failed to create keyBuilderMap: %v", err)
}
md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"})
wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr}
rlsLastErr := errors.New("last RLS request failed")
tests := []struct {
desc string
// The cache entry, as returned by the overridden readCache hook.
cacheEntry *cache.Entry
// Whether or not a default target is configured.
defaultPickExists bool
// Whether or not the RLS request should be throttled.
throttle bool
// Whether or not the test is expected to make a new RLS request.
wantRLSRequest bool
// Whether or not the rlsPicker should delegate to the child picker.
wantChildPick bool
// Whether or not the rlsPicker should delegate to the default picker.
wantDefaultPick bool
// Expected error returned by the rlsPicker under test.
wantErr error
}{
{
desc: "valid entry",
cacheEntry: &cache.Entry{
ExpiryTime: time.Now().Add(defaultTestMaxAge),
StaleTime: time.Now().Add(defaultTestMaxAge),
},
wantChildPick: true,
},
{
desc: "entryStale_requestThrottled",
cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)},
throttle: true,
wantChildPick: true,
},
{
desc: "entryStale_requestNotThrottled",
cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)},
wantRLSRequest: true,
wantChildPick: true,
},
{
desc: "entryExpired_requestThrottled_defaultPickExists",
cacheEntry: &cache.Entry{},
throttle: true,
defaultPickExists: true,
wantDefaultPick: true,
},
{
desc: "entryExpired_requestThrottled_defaultPickNotExists",
cacheEntry: &cache.Entry{},
throttle: true,
wantErr: errRLSThrottled,
},
{
desc: "entryExpired_requestNotThrottled",
cacheEntry: &cache.Entry{},
wantRLSRequest: true,
wantErr: balancer.ErrNoSubConnAvailable,
},
{
desc: "entryExpired_backoffNotExpired_defaultPickExists",
cacheEntry: &cache.Entry{
BackoffTime: time.Now().Add(defaultTestMaxAge),
CallStatus: rlsLastErr,
},
defaultPickExists: true,
},
{
desc: "entryExpired_backoffNotExpired_defaultPickNotExists",
cacheEntry: &cache.Entry{
BackoffTime: time.Now().Add(defaultTestMaxAge),
CallStatus: rlsLastErr,
},
wantErr: rlsLastErr,
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
rlsCh := testutils.NewChannel()
childPicker := newFakePicker()
defaultPicker := newFakePicker()
p := rlsPicker{
kbm: kbm,
readCache: func(key cache.Key) (*cache.Entry, bool) {
if !cmp.Equal(key, wantKey) {
t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey)
}
test.cacheEntry.ChildPicker = childPicker
return test.cacheEntry, false
},
shouldThrottle: func() bool { return test.throttle },
startRLS: func(path string, km keys.KeyMap) {
if !test.wantRLSRequest {
rlsCh.Send(errors.New("RLS request attempted when none was expected"))
return
}
if path != rpcPath {
rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath))
return
}
if km.Str != wantKeyMapStr {
rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr))
return
}
rlsCh.Send(nil)
},
}
if test.defaultPickExists {
p.defaultPick = defaultPicker.Pick
}
gotResult, err := p.Pick(balancer.PickInfo{
FullMethodName: rpcPath,
Ctx: metadata.NewOutgoingContext(context.Background(), md),
})
if err != test.wantErr {
t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr)
}
// If the test specified that a new RLS request should be made,
// verify it.
if test.wantRLSRequest {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if rlsErr, err := rlsCh.Receive(ctx); err != nil || rlsErr != nil {
t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err)
}
}
if test.wantErr != nil {
return
}
// We get here only for cases where we expect the pick to be
// delegated to the child picker or the default picker.
if test.wantChildPick {
if err := verifySubConn(gotResult.SubConn, childPicker.id); err != nil {
t.Fatal(err)
}
}
if test.wantDefaultPick {
if err := verifySubConn(gotResult.SubConn, defaultPicker.id); err != nil {
t.Fatal(err)
}
}
})
}
}
// TestPick_DataCacheHit_PendingCacheHit verifies different Pick scenarios where
// the entry is found both in the data cache and in the pending cache. This
// mostly verifies cases where the entry is stale, but there is already a
// pending RLS request, so no new request should be sent out.
func TestPick_DataCacheHit_PendingCacheHit(t *testing.T) {
const (
rpcPath = "/gFoo/method"
wantKeyMapStr = "k1=v1"
)
kbm, err := initKeyBuilderMap()
if err != nil {
t.Fatalf("Failed to create keyBuilderMap: %v", err)
}
md := metadata.New(map[string]string{"n1": "v1", "n3": "v3"})
wantKey := cache.Key{Path: rpcPath, KeyMap: wantKeyMapStr}
tests := []struct {
desc string
// The cache entry, as returned by the overridden readCache hook.
cacheEntry *cache.Entry
// Whether or not a default target is configured.
defaultPickExists bool
// Expected error returned by the rlsPicker under test.
wantErr error
}{
{
desc: "stale entry",
cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)},
},
{
desc: "stale entry with default picker",
cacheEntry: &cache.Entry{ExpiryTime: time.Now().Add(defaultTestMaxAge)},
defaultPickExists: true,
},
{
desc: "entryExpired_defaultPickExists",
cacheEntry: &cache.Entry{},
defaultPickExists: true,
wantErr: balancer.ErrNoSubConnAvailable,
},
{
desc: "entryExpired_defaultPickNotExists",
cacheEntry: &cache.Entry{},
wantErr: balancer.ErrNoSubConnAvailable,
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
rlsCh := testutils.NewChannel()
childPicker := newFakePicker()
p := rlsPicker{
kbm: kbm,
readCache: func(key cache.Key) (*cache.Entry, bool) {
if !cmp.Equal(key, wantKey) {
t.Fatalf("cache lookup using cacheKey %v, want %v", key, wantKey)
}
test.cacheEntry.ChildPicker = childPicker
return test.cacheEntry, true
},
// Never throttle. We do not expect an RLS request to be sent out anyways.
shouldThrottle: func() bool { return false },
startRLS: func(path string, km keys.KeyMap) {
rlsCh.Send(nil)
},
}
if test.defaultPickExists {
p.defaultPick = func(info balancer.PickInfo) (balancer.PickResult, error) {
// We do not expect the default picker to be invoked at all.
// So, if we get here, we return an error.
return balancer.PickResult{}, errors.New("default picker invoked when expecting a child pick")
}
}
gotResult, err := p.Pick(balancer.PickInfo{
FullMethodName: rpcPath,
Ctx: metadata.NewOutgoingContext(context.Background(), md),
})
if err != test.wantErr {
t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr)
}
// Make sure that no RLS request was sent out.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := rlsCh.Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("RLS request sent out when pending entry exists")
}
if test.wantErr != nil {
return
}
// We get here only for cases where we expect the pick to be
// delegated to the child picker.
if err := verifySubConn(gotResult.SubConn, childPicker.id); err != nil {
t.Fatal(err)
}
})
}
}

View File

@ -0,0 +1,20 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package e2e contains utilities for end-to-end RouteLookupService tests.
package e2e

View File

@ -0,0 +1,131 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package e2e
import (
"encoding/json"
"errors"
"fmt"
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
const (
// RLSChildPolicyTargetNameField is a top-level field name to add to the child
// policy's config, whose value is set to the target for the child policy.
RLSChildPolicyTargetNameField = "Backend"
// RLSChildPolicyBadTarget is a value which is considered a bad target by the
// child policy. This is useful to test bad child policy configuration.
RLSChildPolicyBadTarget = "bad-target"
)
// ErrParseConfigBadTarget is the error returned from ParseConfig when the
// backend field is set to RLSChildPolicyBadTarget.
var ErrParseConfigBadTarget = errors.New("backend field set to RLSChildPolicyBadTarget")
// BalancerFuncs is a set of callbacks which get invoked when the corresponding
// method on the child policy is invoked.
type BalancerFuncs struct {
UpdateClientConnState func(cfg *RLSChildPolicyConfig) error
Close func()
}
// RegisterRLSChildPolicy registers a balancer builder with the given name, to
// be used as a child policy for the RLS LB policy.
//
// The child policy uses a pickfirst balancer under the hood to send all traffic
// to the single backend specified by the `RLSChildPolicyTargetNameField` field
// in its configuration which looks like: {"Backend": "Backend-address"}.
func RegisterRLSChildPolicy(name string, bf *BalancerFuncs) {
balancer.Register(bb{name: name, bf: bf})
}
type bb struct {
name string
bf *BalancerFuncs
}
func (bb bb) Name() string { return bb.name }
func (bb bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
pf := balancer.Get(grpc.PickFirstBalancerName)
b := &bal{
Balancer: pf.Build(cc, opts),
bf: bb.bf,
done: grpcsync.NewEvent(),
}
go b.run()
return b
}
func (bb bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) {
cfg := &RLSChildPolicyConfig{}
if err := json.Unmarshal(c, cfg); err != nil {
return nil, err
}
if cfg.Backend == RLSChildPolicyBadTarget {
return nil, ErrParseConfigBadTarget
}
return cfg, nil
}
type bal struct {
balancer.Balancer
bf *BalancerFuncs
done *grpcsync.Event
}
// RLSChildPolicyConfig is the LB config for the test child policy.
type RLSChildPolicyConfig struct {
serviceconfig.LoadBalancingConfig
Backend string // The target for which this child policy was created.
Random string // A random field to test child policy config changes.
}
func (b *bal) UpdateClientConnState(c balancer.ClientConnState) error {
cfg, ok := c.BalancerConfig.(*RLSChildPolicyConfig)
if !ok {
return fmt.Errorf("received balancer config of type %T, want %T", c.BalancerConfig, &RLSChildPolicyConfig{})
}
if b.bf != nil && b.bf.UpdateClientConnState != nil {
b.bf.UpdateClientConnState(cfg)
}
return b.Balancer.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Addresses: []resolver.Address{{Addr: cfg.Backend}}},
})
}
func (b *bal) Close() {
b.Balancer.Close()
if b.bf != nil && b.bf.Close != nil {
b.bf.Close()
}
b.done.Fire()
}
// run is a dummy goroutine to make sure that child policies are closed at the
// end of tests. If they are not closed, these goroutines will be picked up by
// the leakcheker and tests will fail.
func (b *bal) run() {
<-b.done.Done()
}

View File

@ -0,0 +1,110 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package e2e
import (
"context"
"net"
"sync"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/status"
)
// RouteLookupResponse wraps an RLS response and the associated error to be sent
// to a client when the RouteLookup RPC is invoked.
type RouteLookupResponse struct {
Resp *rlspb.RouteLookupResponse
Err error
}
// FakeRouteLookupServer is a fake implementation of the RouteLookupService.
//
// It is safe for concurrent use.
type FakeRouteLookupServer struct {
rlsgrpc.UnimplementedRouteLookupServiceServer
Address string
mu sync.Mutex
respCb func(context.Context, *rlspb.RouteLookupRequest) *RouteLookupResponse
reqCb func(*rlspb.RouteLookupRequest)
}
// StartFakeRouteLookupServer starts a fake RLS server listening for requests on
// lis. If lis is nil, it creates a new listener on a random local port. The
// returned cancel function should be invoked by the caller upon completion of
// the test.
func StartFakeRouteLookupServer(t *testing.T, lis net.Listener, opts ...grpc.ServerOption) (*FakeRouteLookupServer, func()) {
t.Helper()
if lis == nil {
var err error
lis, err = testutils.LocalTCPListener()
if err != nil {
t.Fatalf("net.Listen() failed: %v", err)
}
}
s := &FakeRouteLookupServer{Address: lis.Addr().String()}
server := grpc.NewServer(opts...)
rlsgrpc.RegisterRouteLookupServiceServer(server, s)
go server.Serve(lis)
return s, func() { server.Stop() }
}
// RouteLookup implements the RouteLookupService.
func (s *FakeRouteLookupServer) RouteLookup(ctx context.Context, req *rlspb.RouteLookupRequest) (*rlspb.RouteLookupResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.reqCb != nil {
s.reqCb(req)
}
if err := ctx.Err(); err != nil {
return nil, status.Error(codes.DeadlineExceeded, err.Error())
}
if s.respCb == nil {
return &rlspb.RouteLookupResponse{}, nil
}
resp := s.respCb(ctx, req)
return resp.Resp, resp.Err
}
// SetResponseCallback sets a callback to be invoked on every RLS request. If
// this callback is set, the response returned by the fake server depends on the
// value returned by the callback. If this callback is not set, the fake server
// responds with an empty response.
func (s *FakeRouteLookupServer) SetResponseCallback(f func(context.Context, *rlspb.RouteLookupRequest) *RouteLookupResponse) {
s.mu.Lock()
s.respCb = f
s.mu.Unlock()
}
// SetRequestCallback sets a callback to be invoked on every RLS request. The
// callback is given the incoming request, and tests can use this to verify that
// the request matches its expectations.
func (s *FakeRouteLookupServer) SetRequestCallback(f func(*rlspb.RouteLookupRequest)) {
s.mu.Lock()
s.reqCb = f
s.mu.Unlock()
}

View File

@ -0,0 +1,100 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package e2e
import (
"errors"
"fmt"
"google.golang.org/grpc/balancer"
rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/protobuf/encoding/protojson"
)
// RLSConfig is a utility type to build service config for the RLS LB policy.
type RLSConfig struct {
RouteLookupConfig *rlspb.RouteLookupConfig
ChildPolicy *internalserviceconfig.BalancerConfig
ChildPolicyConfigTargetFieldName string
}
// ServiceConfigJSON generates service config with a load balancing config
// corresponding to the RLS LB policy.
func (c *RLSConfig) ServiceConfigJSON() (string, error) {
m := protojson.MarshalOptions{
Multiline: true,
Indent: " ",
UseProtoNames: true,
}
routeLookupCfg, err := m.Marshal(c.RouteLookupConfig)
if err != nil {
return "", err
}
childPolicy, err := c.ChildPolicy.MarshalJSON()
if err != nil {
return "", err
}
return fmt.Sprintf(`
{
"loadBalancingConfig": [
{
"rls_experimental": {
"routeLookupConfig": %s,
"childPolicy": %s,
"childPolicyConfigTargetFieldName": %q
}
}
]
}`, string(routeLookupCfg), string(childPolicy), c.ChildPolicyConfigTargetFieldName), nil
}
// LoadBalancingConfig generates load balancing config which can used as part of
// a ClientConnState update to the RLS LB policy.
func (c *RLSConfig) LoadBalancingConfig() (serviceconfig.LoadBalancingConfig, error) {
m := protojson.MarshalOptions{
Multiline: true,
Indent: " ",
UseProtoNames: true,
}
routeLookupCfg, err := m.Marshal(c.RouteLookupConfig)
if err != nil {
return nil, err
}
childPolicy, err := c.ChildPolicy.MarshalJSON()
if err != nil {
return nil, err
}
lbConfigJSON := fmt.Sprintf(`
{
"routeLookupConfig": %s,
"childPolicy": %s,
"childPolicyConfigTargetFieldName": %q
}`, string(routeLookupCfg), string(childPolicy), c.ChildPolicyConfigTargetFieldName)
builder := balancer.Get("rls_experimental")
if builder == nil {
return nil, errors.New("balancer builder not found for RLS LB policy")
}
parser := builder.(balancer.ConfigParser)
return parser.ParseConfig([]byte(lbConfigJSON))
}

View File

@ -80,6 +80,14 @@ func (ss *StubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallSer
// Start starts the server and creates a client connected to it.
func (ss *StubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) error {
if err := ss.StartServer(sopts...); err != nil {
return err
}
return ss.StartClient(dopts...)
}
// StartServer only starts the server. It does not create a client to it.
func (ss *StubServer) StartServer(sopts ...grpc.ServerOption) error {
if ss.Network == "" {
ss.Network = "tcp"
}
@ -102,7 +110,12 @@ func (ss *StubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption)
go s.Serve(lis)
ss.cleanups = append(ss.cleanups, s.Stop)
ss.S = s
return nil
}
// StartClient creates a client connected to this service that the test may use.
// The newly created client will be available in the Client field of StubServer.
func (ss *StubServer) StartClient(dopts ...grpc.DialOption) error {
opts := append([]grpc.DialOption{grpc.WithInsecure()}, dopts...)
if ss.R != nil {
ss.Target = ss.R.Scheme() + ":///" + ss.Address

View File

@ -83,12 +83,11 @@ func (l *RestartableListener) Addr() net.Addr {
func (l *RestartableListener) Stop() {
l.mu.Lock()
l.stopped = true
tmp := l.conns
l.conns = nil
l.mu.Unlock()
for _, conn := range tmp {
for _, conn := range l.conns {
conn.Close()
}
l.conns = nil
l.mu.Unlock()
}
// Restart gets a previously stopped listener to start accepting connections.