diff --git a/balancer/rls/control_channel.go b/balancer/rls/control_channel.go
index f2ad8bc72..28f063e73 100644
--- a/balancer/rls/control_channel.go
+++ b/balancer/rls/control_channel.go
@@ -29,7 +29,9 @@ import (
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials/insecure"
 	"google.golang.org/grpc/internal"
+	"google.golang.org/grpc/internal/buffer"
 	internalgrpclog "google.golang.org/grpc/internal/grpclog"
+	"google.golang.org/grpc/internal/grpcsync"
 	"google.golang.org/grpc/internal/pretty"
 	rlsgrpc "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
 	rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1"
@@ -55,9 +57,12 @@ type controlChannel struct {
 	// hammering the RLS service while it is overloaded or down.
 	throttler adaptiveThrottler
 
-	cc     *grpc.ClientConn
-	client rlsgrpc.RouteLookupServiceClient
-	logger *internalgrpclog.PrefixLogger
+	cc                  *grpc.ClientConn
+	client              rlsgrpc.RouteLookupServiceClient
+	logger              *internalgrpclog.PrefixLogger
+	connectivityStateCh *buffer.Unbounded
+	unsubscribe         func()
+	monitorDoneCh       chan struct{}
 }
 
 // newControlChannel creates a controlChannel to rlsServerName and uses
@@ -65,9 +70,11 @@ type controlChannel struct {
 // gRPC channel.
 func newControlChannel(rlsServerName, serviceConfig string, rpcTimeout time.Duration, bOpts balancer.BuildOptions, backToReadyFunc func()) (*controlChannel, error) {
 	ctrlCh := &controlChannel{
-		rpcTimeout:      rpcTimeout,
-		backToReadyFunc: backToReadyFunc,
-		throttler:       newAdaptiveThrottler(),
+		rpcTimeout:          rpcTimeout,
+		backToReadyFunc:     backToReadyFunc,
+		throttler:           newAdaptiveThrottler(),
+		connectivityStateCh: buffer.NewUnbounded(),
+		monitorDoneCh:       make(chan struct{}),
 	}
 	ctrlCh.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[rls-control-channel %p] ", ctrlCh))
 
@@ -75,17 +82,37 @@ func newControlChannel(rlsServerName, serviceConfig string, rpcTimeout time.Dura
 	if err != nil {
 		return nil, err
 	}
-	ctrlCh.cc, err = grpc.Dial(rlsServerName, dopts...)
+	ctrlCh.cc, err = grpc.NewClient(rlsServerName, dopts...)
 	if err != nil {
 		return nil, err
 	}
+	// Subscribe to connectivity state before connecting to avoid missing initial
+	// updates, which are only delivered to active subscribers.
+	ctrlCh.unsubscribe = internal.SubscribeToConnectivityStateChanges.(func(cc *grpc.ClientConn, s grpcsync.Subscriber) func())(ctrlCh.cc, ctrlCh)
+	ctrlCh.cc.Connect()
 	ctrlCh.client = rlsgrpc.NewRouteLookupServiceClient(ctrlCh.cc)
 	ctrlCh.logger.Infof("Control channel created to RLS server at: %v", rlsServerName)
-
-	go ctrlCh.monitorConnectivityState()
+	start := make(chan struct{})
+	go func() {
+		close(start)
+		ctrlCh.monitorConnectivityState()
+	}()
+	// Wait until the monitoring goroutine is running before returning.
+	<-start
 	return ctrlCh, nil
 }
 
+// OnMessage implements the grpcsync.Subscriber interface. It receives
+// connectivity state updates for the control channel's ClientConn and queues
+// them for the connectivity state monitoring goroutine.
+func (cc *controlChannel) OnMessage(msg any) {
+	st, ok := msg.(connectivity.State)
+	if !ok {
+		panic(fmt.Sprintf("Unexpected message type %T, wanted connectivity.State type", msg))
+	}
+	cc.connectivityStateCh.Put(st)
+}
+
 // dialOpts constructs the dial options for the control plane channel.
 func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions, serviceConfig string) ([]grpc.DialOption, error) {
 	// The control plane channel will use the same authority as the parent
@@ -97,7 +124,6 @@ func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions, serviceConfig st
 	if bOpts.Dialer != nil {
 		dopts = append(dopts, grpc.WithContextDialer(bOpts.Dialer))
 	}
-
 	// The control channel will use the channel credentials from the parent
 	// channel, including any call creds associated with the channel creds.
 	var credsOpt grpc.DialOption
@@ -133,6 +159,8 @@ func (cc *controlChannel) dialOpts(bOpts balancer.BuildOptions, serviceConfig st
 
 func (cc *controlChannel) monitorConnectivityState() {
 	cc.logger.Infof("Starting connectivity state monitoring goroutine")
+	defer close(cc.monitorDoneCh)
+
 	// Since we use two mechanisms to deal with RLS server being down:
 	//   - adaptive throttling for the channel as a whole
 	//   - exponential backoff on a per-request basis
@@ -154,39 +182,45 @@ func (cc *controlChannel) monitorConnectivityState() {
 	// returning only one new picker, regardless of how many backoff timers are
 	// cancelled.
 
-	// Using the background context is fine here since we check for the ClientConn
-	// entering SHUTDOWN and return early in that case.
-	ctx := context.Background()
-
-	first := true
-	for {
-		// Wait for the control channel to become READY.
-		for s := cc.cc.GetState(); s != connectivity.Ready; s = cc.cc.GetState() {
-			if s == connectivity.Shutdown {
-				return
-			}
-			cc.cc.WaitForStateChange(ctx, s)
+	// Wait for the control channel to become READY for the first time.
+	for s, ok := <-cc.connectivityStateCh.Get(); s != connectivity.Ready; s, ok = <-cc.connectivityStateCh.Get() {
+		if !ok {
+			return
 		}
-		cc.logger.Infof("Connectivity state is READY")
 
-		if !first {
+		cc.connectivityStateCh.Load()
+		if s == connectivity.Shutdown {
+			return
+		}
+	}
+	cc.connectivityStateCh.Load()
+	cc.logger.Infof("Connectivity state is READY")
+
+	for {
+		s, ok := <-cc.connectivityStateCh.Get()
+		if !ok {
+			return
+		}
+		cc.connectivityStateCh.Load()
+
+		if s == connectivity.Shutdown {
+			return
+		}
+		if s == connectivity.Ready {
 			cc.logger.Infof("Control channel back to READY")
 			cc.backToReadyFunc()
 		}
-		first = false
 
-		// Wait for the control channel to move out of READY.
-		cc.cc.WaitForStateChange(ctx, connectivity.Ready)
-		if cc.cc.GetState() == connectivity.Shutdown {
-			return
-		}
-		cc.logger.Infof("Connectivity state is %s", cc.cc.GetState())
+		cc.logger.Infof("Connectivity state is %s", s)
 	}
 }
 
 func (cc *controlChannel) close() {
-	cc.logger.Infof("Closing control channel")
+	cc.unsubscribe()
+	cc.connectivityStateCh.Close()
+	<-cc.monitorDoneCh
 	cc.cc.Close()
+	cc.logger.Infof("Shutdown")
 }
 
 type lookupCallback func(targets []string, headerData string, err error)