From 105f61423e7ad35e8295ae6b720f22635f7bee55 Mon Sep 17 00:00:00 2001 From: lyuxuan Date: Thu, 1 Nov 2018 10:49:35 -0700 Subject: [PATCH] health: Client LB channel health checking (#2387) --- balancer/balancer.go | 3 + balancer/base/balancer.go | 5 +- balancer/base/base.go | 12 + balancer/roundrobin/roundrobin.go | 2 +- clientconn.go | 108 +++- dialoptions.go | 9 + rpc_util.go | 2 +- service_config.go | 17 +- stream.go | 295 +++++++++ test/healthcheck_test.go | 955 ++++++++++++++++++++++++++++++ 10 files changed, 1394 insertions(+), 14 deletions(-) create mode 100644 test/healthcheck_test.go diff --git a/balancer/balancer.go b/balancer/balancer.go index ee1703f03..b9a8cbf06 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -94,6 +94,9 @@ type NewSubConnOptions struct { // SubConn. If it's nil, the original creds from grpc DialOptions will be // used. CredsBundle credentials.Bundle + // HealthCheckEnabled indicates whether health check service should be + // enabled on this SubConn + HealthCheckEnabled bool } // ClientConn represents a gRPC ClientConn. diff --git a/balancer/base/balancer.go b/balancer/base/balancer.go index 23d13511b..6f3f0c86e 100644 --- a/balancer/base/balancer.go +++ b/balancer/base/balancer.go @@ -29,6 +29,7 @@ import ( type baseBuilder struct { name string pickerBuilder PickerBuilder + config Config } func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { @@ -43,6 +44,7 @@ func (bb *baseBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) // ErrNoSubConnAvailable, because when state of a SubConn changes, we // may call UpdateBalancerState with this picker. picker: NewErrPicker(balancer.ErrNoSubConnAvailable), + config: bb.config, } } @@ -60,6 +62,7 @@ type baseBalancer struct { subConns map[resolver.Address]balancer.SubConn scStates map[balancer.SubConn]connectivity.State picker balancer.Picker + config Config } func (b *baseBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { @@ -74,7 +77,7 @@ func (b *baseBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) addrsSet[a] = struct{}{} if _, ok := b.subConns[a]; !ok { // a is a new address (not existing in b.subConns). - sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{}) + sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{HealthCheckEnabled: b.config.HealthCheck}) if err != nil { grpclog.Warningf("base.baseBalancer: failed to create new SubConn: %v", err) continue diff --git a/balancer/base/base.go b/balancer/base/base.go index 012ace2f2..34b1f2994 100644 --- a/balancer/base/base.go +++ b/balancer/base/base.go @@ -45,8 +45,20 @@ type PickerBuilder interface { // NewBalancerBuilder returns a balancer builder. The balancers // built by this builder will use the picker builder to build pickers. func NewBalancerBuilder(name string, pb PickerBuilder) balancer.Builder { + return NewBalancerBuilderWithConfig(name, pb, Config{}) +} + +// Config contains the config info about the base balancer builder. +type Config struct { + // HealthCheck indicates whether health checking should be enabled for this specific balancer. + HealthCheck bool +} + +// NewBalancerBuilderWithConfig returns a base balancer builder configured by the provided config. 
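+//
+// For example, a balancer that opts in to client-side health checking could be
+// built with (pb being some PickerBuilder; "my_balancer" is a placeholder name):
+//
+//	base.NewBalancerBuilderWithConfig("my_balancer", pb, base.Config{HealthCheck: true})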
+func NewBalancerBuilderWithConfig(name string, pb PickerBuilder, config Config) balancer.Builder { return &baseBuilder{ name: name, pickerBuilder: pb, + config: config, } } diff --git a/balancer/roundrobin/roundrobin.go b/balancer/roundrobin/roundrobin.go index 2eda0a1c2..1da04e693 100644 --- a/balancer/roundrobin/roundrobin.go +++ b/balancer/roundrobin/roundrobin.go @@ -36,7 +36,7 @@ const Name = "round_robin" // newBuilder creates a new roundrobin balancer builder. func newBuilder() balancer.Builder { - return base.NewBalancerBuilder(Name, &rrPickerBuilder{}) + return base.NewBalancerBuilderWithConfig(Name, &rrPickerBuilder{}, base.Config{HealthCheck: true}) } func init() { diff --git a/clientconn.go b/clientconn.go index f49ac3f9b..5d43cf465 100644 --- a/clientconn.go +++ b/clientconn.go @@ -36,6 +36,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/transport" @@ -306,7 +307,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * break } else if cc.dopts.copts.FailOnNonTempDialError && s == connectivity.TransientFailure { if err = cc.blockingpicker.connectionError(); err != nil { - terr, ok := err.(interface{ Temporary() bool }) + terr, ok := err.(interface { + Temporary() bool + }) if ok && !terr.Temporary() { return nil, err } @@ -715,6 +718,12 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig { return m } +func (cc *ClientConn) healthCheckConfig() *healthCheckConfig { + cc.mu.RLock() + defer cc.mu.RUnlock() + return cc.sc.healthCheckConfig +} + func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, func(balancer.DoneInfo), error) { hdr, _ := metadata.FromOutgoingContext(ctx) t, done, err := cc.blockingpicker.pick(ctx, failfast, balancer.PickOptions{ @@ -877,6 +886,10 @@ type addrConn struct { acbw balancer.SubConn scopts balancer.NewSubConnOptions + // transport is set when there's a viable transport (note: ac state may not be READY as LB channel + // health checking may require server to report healthy to set ac to READY), and is reset + // to nil when the current transport should no longer be used to create a stream (e.g. after GoAway + // is received, transport is closed, ac has been torn down). transport transport.ClientTransport // The current transport. mu sync.Mutex @@ -903,6 +916,8 @@ type addrConn struct { czData *channelzData successfulHandshake bool + + healthCheckEnabled bool } // Note: this requires a lock on ac.mu. @@ -956,6 +971,8 @@ func (ac *addrConn) resetTransport(resolveNow bool) { return } + // The transport that was used before is no longer viable. + ac.transport = nil // If the connection is READY, a failure must have occurred. 
// Otherwise, we'll consider this a transient failure when:
 // We've exhausted all addresses
@@ -1044,7 +1061,10 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts
 	var serverPrefaceReceived bool
 	var clientPrefaceWrote bool
 
+	hcCtx, hcCancel := context.WithCancel(ac.ctx)
+
 	onGoAway := func(r transport.GoAwayReason) {
+		hcCancel()
 		ac.mu.Lock()
 		ac.adjustParams(r)
 		ac.mu.Unlock()
@@ -1059,6 +1079,7 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts
 	prefaceTimer := time.NewTimer(connectDeadline.Sub(time.Now()))
 
 	onClose := func() {
+		hcCancel()
 		close(onCloseCalled)
 		prefaceTimer.Stop()
 
@@ -1166,22 +1187,46 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts
 		return err
 	}
 
+	// Now there is a viable transport to be used, so set ac.transport to reflect the new viable transport.
+	ac.mu.Lock()
+	if ac.state == connectivity.Shutdown {
+		ac.mu.Unlock()
+		close(skipReset)
+		newTr.Close()
+		return nil
+	}
+	ac.transport = newTr
+	ac.mu.Unlock()
+
+	healthCheckConfig := ac.cc.healthCheckConfig()
+	// LB channel health checking is only enabled when all four of the requirements below are met:
+	// 1. it is not disabled by the user with the WithDisableHealthCheck DialOption,
+	// 2. the internal.HealthCheckFunc is set by importing the grpc/health package,
+	// 3. a service config with a non-empty healthCheckConfig field is provided,
+	// 4. the current load balancer allows it.
+	if !ac.cc.dopts.disableHealthCheck && healthCheckConfig != nil && ac.scopts.HealthCheckEnabled {
+		if internal.HealthCheckFunc != nil {
+			go ac.startHealthCheck(hcCtx, newTr, addr, healthCheckConfig.ServiceName)
+			close(allowedToReset)
+			return nil
+		}
+		// TODO: add a link to the health check doc in the error message.
+		grpclog.Error("the client side LB channel health check function has not been set.")
+	}
+
+	// No LB channel health check case.
 	ac.mu.Lock()
 
 	if ac.state == connectivity.Shutdown {
 		ac.mu.Unlock()
 
-		// We don't want to reset during this close because we prefer to kick out of this function and let the loop
-		// in resetTransport take care of reconnecting.
+		// unblock onGoAway/onClose callback.
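+		// Note: skipReset and allowedToReset act as a handshake with the blocked
+		// onGoAway/onClose callbacks; closing skipReset tells them to skip the
+		// reset, while closing allowedToReset permits it.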
close(skipReset) - - newTr.Close() return errConnClosing } ac.updateConnectivityState(connectivity.Ready) ac.cc.handleSubConnStateChange(ac.acbw, ac.state) - ac.transport = newTr ac.curAddr = addr ac.mu.Unlock() @@ -1192,6 +1237,51 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts return nil } +func (ac *addrConn) startHealthCheck(ctx context.Context, newTr transport.ClientTransport, addr resolver.Address, serviceName string) { + // Set up the health check helper functions + newStream := func() (interface{}, error) { + return ac.newClientStream(ctx, &StreamDesc{ServerStreams: true}, "/grpc.health.v1.Health/Watch", newTr) + } + firstReady := true + reportHealth := func(ok bool) { + ac.mu.Lock() + defer ac.mu.Unlock() + if ac.transport != newTr { + return + } + if ok { + if firstReady { + firstReady = false + ac.curAddr = addr + } + if ac.state != connectivity.Ready { + ac.updateConnectivityState(connectivity.Ready) + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + } + } else { + if ac.state != connectivity.TransientFailure { + ac.updateConnectivityState(connectivity.TransientFailure) + ac.cc.handleSubConnStateChange(ac.acbw, ac.state) + } + } + } + + err := internal.HealthCheckFunc(ctx, newStream, reportHealth, serviceName) + if err != nil { + if status.Code(err) == codes.Unimplemented { + if channelz.IsOn() { + channelz.AddTraceEvent(ac.channelzID, &channelz.TraceEventDesc{ + Desc: "Subchannel health check is unimplemented at server side, thus health check is disabled", + Severity: channelz.CtError, + }) + } + grpclog.Error("Subchannel health check is unimplemented at server side, thus health check is disabled") + } else { + grpclog.Errorf("HealthCheckFunc exits with unexpected error %v", err) + } + } +} + // nextAddr increments the addrIdx if there are more addresses to try. If // there are no more addrs to try it will re-resolve, set addrIdx to 0, and // increment the backoffIdx. @@ -1279,6 +1369,8 @@ func (ac *addrConn) tearDown(err error) { ac.mu.Unlock() return } + curTr := ac.transport + ac.transport = nil // We have to set the state to Shutdown before anything else to prevent races // between setting the state and logic that waits on context cancelation / etc. ac.updateConnectivityState(connectivity.Shutdown) @@ -1286,14 +1378,14 @@ func (ac *addrConn) tearDown(err error) { ac.tearDownErr = err ac.cc.handleSubConnStateChange(ac.acbw, ac.state) ac.curAddr = resolver.Address{} - if err == errConnDrain && ac.transport != nil { + if err == errConnDrain && curTr != nil { // GracefulClose(...) may be executed multiple times when // i) receiving multiple GoAway frames from the server; or // ii) there are concurrent name resolver/Balancer triggered // address removal and GoAway. // We have to unlock and re-lock here because GracefulClose => Close => onClose, which requires locking ac.mu. ac.mu.Unlock() - ac.transport.GracefulClose() + curTr.GracefulClose() ac.mu.Lock() } if channelz.IsOn() { diff --git a/dialoptions.go b/dialoptions.go index 99b495272..cc9ec56b9 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -59,6 +59,7 @@ type dialOptions struct { channelzParentID int64 disableServiceConfig bool disableRetry bool + disableHealthCheck bool } // DialOption configures how we set up the connection. @@ -454,6 +455,14 @@ func WithMaxHeaderListSize(s uint32) DialOption { }) } +// WithDisableHealthCheck disables the LB channel health checking for all SubConns of this ClientConn. +// +// This API is EXPERIMENTAL. 
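+//
+// A minimal usage sketch (the target is a placeholder):
+//
+//	cc, err := grpc.Dial("my.server:443", grpc.WithInsecure(), grpc.WithDisableHealthCheck())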
+func WithDisableHealthCheck() DialOption {
+	return newFuncDialOption(func(o *dialOptions) {
+		o.disableHealthCheck = true
+	})
+}
 func defaultDialOptions() dialOptions {
 	return dialOptions{
 		disableRetry: !envconfig.Retry,
diff --git a/rpc_util.go b/rpc_util.go
index 51b952b08..b645998e3 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -155,7 +155,7 @@ func (d *gzipDecompressor) Type() string {
 type callInfo struct {
 	compressorType        string
 	failFast              bool
-	stream                *clientStream
+	stream                ClientStream
 	maxReceiveMessageSize *int
 	maxSendMessageSize    *int
 	creds                 credentials.PerRPCCredentials
diff --git a/service_config.go b/service_config.go
index a305fe0a4..162857e20 100644
--- a/service_config.go
+++ b/service_config.go
@@ -96,6 +96,15 @@ type ServiceConfig struct {
 	// If token_count is less than or equal to maxTokens / 2, then RPCs will not
 	// be retried and hedged RPCs will not be sent.
 	retryThrottling *retryThrottlingPolicy
+	// healthCheckConfig must be set as one of the requirements to enable LB channel
+	// health check.
+	healthCheckConfig *healthCheckConfig
+}
+
+// healthCheckConfig defines the go-native version of the LB channel health check config.
+type healthCheckConfig struct {
+	// ServiceName is the service name to use in the health-checking request.
+	ServiceName string
 }
 
 // retryPolicy defines the go-native version of the retry policy defined by the
@@ -226,6 +235,7 @@ type jsonSC struct {
 	LoadBalancingPolicy *string
 	MethodConfig        *[]jsonMC
 	RetryThrottling     *retryThrottlingPolicy
+	HealthCheckConfig   *healthCheckConfig
 }
 
 func parseServiceConfig(js string) (ServiceConfig, error) {
@@ -239,9 +249,10 @@ func parseServiceConfig(js string) (ServiceConfig, error) {
 		return ServiceConfig{}, err
 	}
 	sc := ServiceConfig{
-		LB:              rsc.LoadBalancingPolicy,
-		Methods:         make(map[string]MethodConfig),
-		retryThrottling: rsc.RetryThrottling,
+		LB:                rsc.LoadBalancingPolicy,
+		Methods:           make(map[string]MethodConfig),
+		retryThrottling:   rsc.RetryThrottling,
+		healthCheckConfig: rsc.HealthCheckConfig,
 	}
 	if rsc.MethodConfig == nil {
 		return sc, nil
diff --git a/stream.go b/stream.go
index 43a0372ba..9e217274f 100644
--- a/stream.go
+++ b/stream.go
@@ -26,6 +26,8 @@ import (
 	"sync"
 	"time"
 
+	"google.golang.org/grpc/connectivity"
+
 	"golang.org/x/net/context"
 	"golang.org/x/net/trace"
 	"google.golang.org/grpc/balancer"
@@ -950,6 +952,299 @@ func (a *csAttempt) finish(err error) {
 	a.mu.Unlock()
 }
 
+func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, method string, t transport.ClientTransport, opts ...CallOption) (_ ClientStream, err error) {
+	ac.mu.Lock()
+	if ac.transport != t {
+		ac.mu.Unlock()
+		return nil, status.Error(codes.Canceled, "the provided transport is no longer valid to use")
+	}
+	// Transition to CONNECTING state when an attempt starts.
+	if ac.state != connectivity.Connecting {
+		ac.updateConnectivityState(connectivity.Connecting)
+		ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
+	}
+	ac.mu.Unlock()
+
+	if t == nil {
+		// TODO: return RPC error here?
+		return nil, errors.New("transport provided is nil")
+	}
+	// defaultCallInfo contains unnecessary info (e.g. failFast, maxRetryRPCBufferSize), so we just initialize an empty struct.
+ c := &callInfo{} + + for _, o := range opts { + if err := o.before(c); err != nil { + return nil, toRPCErr(err) + } + } + c.maxReceiveMessageSize = getMaxSize(nil, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize) + c.maxSendMessageSize = getMaxSize(nil, c.maxSendMessageSize, defaultServerMaxSendMessageSize) + + // Possible context leak: + // The cancel function for the child context we create will only be called + // when RecvMsg returns a non-nil error, if the ClientConn is closed, or if + // an error is generated by SendMsg. + // https://github.com/grpc/grpc-go/issues/1818. + ctx, cancel := context.WithCancel(ctx) + defer func() { + if err != nil { + cancel() + } + }() + + if err := setCallInfoCodec(c); err != nil { + return nil, err + } + + callHdr := &transport.CallHdr{ + Host: ac.cc.authority, + Method: method, + ContentSubtype: c.contentSubtype, + } + + // Set our outgoing compression according to the UseCompressor CallOption, if + // set. In that case, also find the compressor from the encoding package. + // Otherwise, use the compressor configured by the WithCompressor DialOption, + // if set. + var cp Compressor + var comp encoding.Compressor + if ct := c.compressorType; ct != "" { + callHdr.SendCompress = ct + if ct != encoding.Identity { + comp = encoding.GetCompressor(ct) + if comp == nil { + return nil, status.Errorf(codes.Internal, "grpc: Compressor is not installed for requested grpc-encoding %q", ct) + } + } + } else if ac.cc.dopts.cp != nil { + callHdr.SendCompress = ac.cc.dopts.cp.Type() + cp = ac.cc.dopts.cp + } + if c.creds != nil { + callHdr.Creds = c.creds + } + + as := &addrConnStream{ + callHdr: callHdr, + ac: ac, + ctx: ctx, + cancel: cancel, + opts: opts, + callInfo: c, + desc: desc, + codec: c.codec, + cp: cp, + comp: comp, + t: t, + } + + as.callInfo.stream = as + s, err := as.t.NewStream(as.ctx, as.callHdr) + if err != nil { + err = toRPCErr(err) + return nil, err + } + as.s = s + as.p = &parser{r: s} + ac.incrCallsStarted() + if desc != unaryStreamDesc { + // Listen on cc and stream contexts to cleanup when the user closes the + // ClientConn or cancels the stream context. In all other cases, an error + // should already be injected into the recv buffer by the transport, which + // the client will eventually receive, and then we will cancel the stream's + // context in clientStream.finish. + go func() { + select { + case <-ac.ctx.Done(): + as.finish(status.Error(codes.Canceled, "grpc: the SubConn is closing")) + case <-ctx.Done(): + as.finish(toRPCErr(ctx.Err())) + } + }() + } + return as, nil +} + +type addrConnStream struct { + s *transport.Stream + ac *addrConn + callHdr *transport.CallHdr + cancel context.CancelFunc + opts []CallOption + callInfo *callInfo + t transport.ClientTransport + ctx context.Context + sentLast bool + desc *StreamDesc + codec baseCodec + cp Compressor + comp encoding.Compressor + decompSet bool + dc Decompressor + decomp encoding.Compressor + p *parser + done func(balancer.DoneInfo) + mu sync.Mutex + finished bool +} + +func (as *addrConnStream) Header() (metadata.MD, error) { + m, err := as.s.Header() + if err != nil { + as.finish(toRPCErr(err)) + } + return m, err +} + +func (as *addrConnStream) Trailer() metadata.MD { + return as.s.Trailer() +} + +func (as *addrConnStream) CloseSend() error { + if as.sentLast { + // TODO: return an error and finish the stream instead, due to API misuse? 
+ return nil + } + as.sentLast = true + + as.t.Write(as.s, nil, nil, &transport.Options{Last: true}) + // Always return nil; io.EOF is the only error that might make sense + // instead, but there is no need to signal the client to call RecvMsg + // as the only use left for the stream after CloseSend is to call + // RecvMsg. This also matches historical behavior. + return nil +} + +func (as *addrConnStream) Context() context.Context { + return as.s.Context() +} + +func (as *addrConnStream) SendMsg(m interface{}) (err error) { + defer func() { + if err != nil && err != io.EOF { + // Call finish on the client stream for errors generated by this SendMsg + // call, as these indicate problems created by this client. (Transport + // errors are converted to an io.EOF error in csAttempt.sendMsg; the real + // error will be returned from RecvMsg eventually in that case, or be + // retried.) + as.finish(err) + } + }() + if as.sentLast { + return status.Errorf(codes.Internal, "SendMsg called after CloseSend") + } + if !as.desc.ClientStreams { + as.sentLast = true + } + data, err := encode(as.codec, m) + if err != nil { + return err + } + compData, err := compress(data, as.cp, as.comp) + if err != nil { + return err + } + hdr, payld := msgHeader(data, compData) + // TODO(dfawley): should we be checking len(data) instead? + if len(payld) > *as.callInfo.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize) + } + + if err := as.t.Write(as.s, hdr, payld, &transport.Options{Last: !as.desc.ClientStreams}); err != nil { + if !as.desc.ClientStreams { + // For non-client-streaming RPCs, we return nil instead of EOF on error + // because the generated code requires it. finish is not called; RecvMsg() + // will call it with the stream's status independently. + return nil + } + return io.EOF + } + + if channelz.IsOn() { + as.t.IncrMsgSent() + } + return nil +} + +func (as *addrConnStream) RecvMsg(m interface{}) (err error) { + defer func() { + if err != nil || !as.desc.ServerStreams { + // err != nil or non-server-streaming indicates end of stream. + as.finish(err) + } + }() + + if !as.decompSet { + // Block until we receive headers containing received message encoding. + if ct := as.s.RecvCompress(); ct != "" && ct != encoding.Identity { + if as.dc == nil || as.dc.Type() != ct { + // No configured decompressor, or it does not match the incoming + // message encoding; attempt to find a registered compressor that does. + as.dc = nil + as.decomp = encoding.GetCompressor(ct) + } + } else { + // No compression is used; disable our decompressor. + as.dc = nil + } + // Only initialize this state once per stream. + as.decompSet = true + } + err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp) + if err != nil { + if err == io.EOF { + if statusErr := as.s.Status().Err(); statusErr != nil { + return statusErr + } + return io.EOF // indicates successful end of stream. + } + return toRPCErr(err) + } + + if channelz.IsOn() { + as.t.IncrMsgRecv() + } + if as.desc.ServerStreams { + // Subsequent messages should be received by subsequent RecvMsg calls. + return nil + } + + // Special handling for non-server-stream rpcs. + // This recv expects EOF or errors, so we don't collect inPayload. 
+	err = recv(as.p, as.codec, as.s, as.dc, m, *as.callInfo.maxReceiveMessageSize, nil, as.decomp)
+	if err == nil {
+		return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
+	}
+	if err == io.EOF {
+		return as.s.Status().Err() // non-server streaming Recv returns nil on success
+	}
+	return toRPCErr(err)
+}
+
+func (as *addrConnStream) finish(err error) {
+	as.mu.Lock()
+	if as.finished {
+		as.mu.Unlock()
+		return
+	}
+	as.finished = true
+	if err == io.EOF {
+		// Ending a stream with EOF indicates success.
+		err = nil
+	}
+	if as.s != nil {
+		as.t.CloseStream(as.s, err)
+	}
+
+	if err != nil {
+		as.ac.incrCallsFailed()
+	} else {
+		as.ac.incrCallsSucceeded()
+	}
+	as.cancel()
+	as.mu.Unlock()
+}
+
 // ServerStream defines the server-side behavior of a streaming RPC.
 //
 // All errors returned from ServerStream methods are compatible with the
diff --git a/test/healthcheck_test.go b/test/healthcheck_test.go
new file mode 100644
index 000000000..ebd5974a7
--- /dev/null
+++ b/test/healthcheck_test.go
@@ -0,0 +1,955 @@
+/*
+ *
+ * Copyright 2018 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package test
+
+import (
+	"errors"
+	"fmt"
+	"net"
+	"sync"
+	"testing"
+	"time"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/connectivity"
+	_ "google.golang.org/grpc/health"
+	healthpb "google.golang.org/grpc/health/grpc_health_v1"
+	"google.golang.org/grpc/internal"
+	"google.golang.org/grpc/internal/channelz"
+	"google.golang.org/grpc/internal/leakcheck"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/resolver/manual"
+	"google.golang.org/grpc/status"
+	testpb "google.golang.org/grpc/test/grpc_testing"
+)
+
+var testHealthCheckFunc = internal.HealthCheckFunc
+
+func replaceHealthCheckFunc(f func(context.Context, func() (interface{}, error), func(bool), string) error) func() {
+	oldHcFunc := internal.HealthCheckFunc
+	internal.HealthCheckFunc = f
+	return func() {
+		internal.HealthCheckFunc = oldHcFunc
+	}
+}
+
+func newTestHealthServer() *testHealthServer {
+	return newTestHealthServerWithWatchFunc(defaultWatchFunc)
+}
+
+func newTestHealthServerWithWatchFunc(f func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthpb.Health_WatchServer) error) *testHealthServer {
+	return &testHealthServer{
+		watchFunc: f,
+		update:    make(chan struct{}, 1),
+		status:    make(map[string]healthpb.HealthCheckResponse_ServingStatus),
+	}
+}
+
+// defaultWatchFunc will send a HealthCheckResponse to the client whenever SetServingStatus is called.
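+// It only accepts requests for service "foo"; for those it blocks until the stream
+// context is done, sending the currently stored status each time the update channel fires.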
+func defaultWatchFunc(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthpb.Health_WatchServer) error {
+	if in.Service != "foo" {
+		return status.Error(codes.FailedPrecondition,
+			"the defaultWatchFunc only handles requests with service name \"foo\"")
+	}
+	var done bool
+	for {
+		select {
+		case <-stream.Context().Done():
+			done = true
+		case <-s.update:
+		}
+		if done {
+			break
+		}
+		s.mu.Lock()
+		resp := &healthpb.HealthCheckResponse{
+			Status: s.status[in.Service],
+		}
+		s.mu.Unlock()
+		stream.SendMsg(resp)
+	}
+	return nil
+}
+
+type testHealthServer struct {
+	watchFunc func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthpb.Health_WatchServer) error
+	mu        sync.Mutex
+	status    map[string]healthpb.HealthCheckResponse_ServingStatus
+	update    chan struct{}
+}
+
+func (s *testHealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
+	return &healthpb.HealthCheckResponse{
+		Status: healthpb.HealthCheckResponse_SERVING,
+	}, nil
+}
+
+func (s *testHealthServer) Watch(in *healthpb.HealthCheckRequest, stream healthpb.Health_WatchServer) error {
+	return s.watchFunc(s, in, stream)
+}
+
+// SetServingStatus is called when we need to reset the serving status of a service
+// or insert a new service entry into the status map.
+func (s *testHealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
+	s.mu.Lock()
+	s.status[service] = status
+	select {
+	case <-s.update:
+	default:
+	}
+	s.update <- struct{}{}
+	s.mu.Unlock()
+}
+
+func TestHealthCheckWatchStateChange(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServer()
+	healthpb.RegisterHealthServer(s, ts)
+	go s.Serve(lis)
+	defer s.Stop()
+
+	// The table below shows the expected series of addrConn connectivity transitions when the server
+	// updates its health status. As there's only one addrConn corresponding to the ClientConn in this
+	// test, we use the ClientConn's connectivity state as the addrConn connectivity state.
+ //+------------------------------+-------------------------------------------+ + //| Health Check Returned Status | Expected addrConn Connectivity Transition | + //+------------------------------+-------------------------------------------+ + //| NOT_SERVING | ->TRANSIENT FAILURE | + //| SERVING | ->READY | + //| SERVICE_UNKNOWN | ->TRANSIENT FAILURE | + //| SERVING | ->READY | + //| UNKNOWN | ->TRANSIENT FAILURE | + //+------------------------------+-------------------------------------------+ + ts.SetServingStatus("foo", healthpb.HealthCheckResponse_NOT_SERVING) + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin")) + if err != nil { + t.Fatalf("dial failed due to err: %v", err) + } + defer cc.Close() + + r.NewServiceConfig(`{ + "healthCheckConfig": { + "serviceName": "foo" + } +}`) + r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}}) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if ok := cc.WaitForStateChange(ctx, connectivity.Idle); !ok { + t.Fatal("ClientConn is still in IDLE state when the context times out.") + } + if ok := cc.WaitForStateChange(ctx, connectivity.Connecting); !ok { + t.Fatal("ClientConn is still in CONNECTING state when the context times out.") + } + if s := cc.GetState(); s != connectivity.TransientFailure { + t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s) + } + + ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING) + if ok := cc.WaitForStateChange(ctx, connectivity.TransientFailure); !ok { + t.Fatal("ClientConn is still in TRANSIENT FAILURE state when the context times out.") + } + if s := cc.GetState(); s != connectivity.Ready { + t.Fatalf("ClientConn is in %v state, want READY", s) + } + + ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVICE_UNKNOWN) + if ok := cc.WaitForStateChange(ctx, connectivity.Ready); !ok { + t.Fatal("ClientConn is still in READY state when the context times out.") + } + if s := cc.GetState(); s != connectivity.TransientFailure { + t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s) + } + + ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING) + if ok := cc.WaitForStateChange(ctx, connectivity.TransientFailure); !ok { + t.Fatal("ClientConn is still in TRANSIENT FAILURE state when the context times out.") + } + if s := cc.GetState(); s != connectivity.Ready { + t.Fatalf("ClientConn is in %v state, want READY", s) + } + + ts.SetServingStatus("foo", healthpb.HealthCheckResponse_UNKNOWN) + if ok := cc.WaitForStateChange(ctx, connectivity.Ready); !ok { + t.Fatal("ClientConn is still in READY state when the context times out.") + } + if s := cc.GetState(); s != connectivity.TransientFailure { + t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s) + } +} + +// In the case of a goaway received, the health check stream should be terminated and health check +// function should exit. 
+func TestHealthCheckWithGoAway(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServer()
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
+	hcExitChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		err := testHealthCheckFunc(ctx, newStream, update, service)
+		close(hcExitChan)
+		return err
+	}
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	defer cc.Close()
+
+	tc := testpb.NewTestServiceClient(cc)
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "foo"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+	// Make some RPCs to make sure the connection is working.
+	if err := verifyResultWithDelay(func() (bool, error) {
+		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		return true, nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	// The stream RPC should persist through the GoAway event.
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	respParam := []*testpb.ResponseParameters{{Size: 1}}
+	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
+	if err != nil {
+		t.Fatal(err)
+	}
+	req := &testpb.StreamingOutputCallRequest{
+		ResponseParameters: respParam,
+		Payload:            payload,
+	}
+	if err := stream.Send(req); err != nil {
+		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
+	}
+	if _, err := stream.Recv(); err != nil {
+		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
+	}
+
+	select {
+	case <-hcExitChan:
+		t.Fatal("Health check function has exited, which is not expected.")
+	default:
+	}
+
+	// server sends GoAway
+	go s.GracefulStop()
+
+	select {
+	case <-hcExitChan:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Health check function has not exited after 5s.")
+	}
+
+	// The existing RPC should still be good to proceed.
+	if err := stream.Send(req); err != nil {
+		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
+	}
+	if _, err := stream.Recv(); err != nil {
+		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
+	}
+}
+
+func TestHealthCheckWithConnClose(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServer()
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
+	hcExitChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		err := testHealthCheckFunc(ctx, newStream, update, service)
+		close(hcExitChan)
+		return err
+	}
+
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	defer cc.Close()
+	tc := testpb.NewTestServiceClient(cc)
+
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "foo"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+	// Make some RPCs to make sure the connection is working.
+	if err := verifyResultWithDelay(func() (bool, error) {
+		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		return true, nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case <-hcExitChan:
+		t.Fatal("Health check function has exited, which is not expected.")
+	default:
+	}
+	// server closes the connection
+	s.Stop()
+
+	select {
+	case <-hcExitChan:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Health check function has not exited after 5s.")
+	}
+}
+
+// addrConn drain happens when the addrConn gets torn down because its address is no longer in the
+// address list returned by the resolver.
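+// The test below triggers that teardown by pushing an empty address list through the manual resolver.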
+func TestHealthCheckWithAddrConnDrain(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServer()
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
+	hcExitChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		err := testHealthCheckFunc(ctx, newStream, update, service)
+		close(hcExitChan)
+		return err
+	}
+
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	defer cc.Close()
+
+	tc := testpb.NewTestServiceClient(cc)
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "foo"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+	// Make some RPCs to make sure the connection is working.
+	if err := verifyResultWithDelay(func() (bool, error) {
+		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		return true, nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	// The stream RPC should persist through the addrConn drain.
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	respParam := []*testpb.ResponseParameters{{Size: 1}}
+	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
+	if err != nil {
+		t.Fatal(err)
+	}
+	req := &testpb.StreamingOutputCallRequest{
+		ResponseParameters: respParam,
+		Payload:            payload,
+	}
+	if err := stream.Send(req); err != nil {
+		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
+	}
+	if _, err := stream.Recv(); err != nil {
+		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
+	}
+
+	select {
+	case <-hcExitChan:
+		t.Fatal("Health check function has exited, which is not expected.")
+	default:
+	}
+	// trigger teardown of the ac
+	r.NewAddress([]resolver.Address{})
+
+	select {
+	case <-hcExitChan:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Health check function has not exited after 5s.")
+	}
+
+	// The existing RPC should still be good to proceed.
+	if err := stream.Send(req); err != nil {
+		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
+	}
+	if _, err := stream.Recv(); err != nil {
+		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
+	}
+}
+
+// ClientConn close will lead to its addrConns being torn down.
+func TestHealthCheckWithClientConnClose(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServer()
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
+	hcExitChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		err := testHealthCheckFunc(ctx, newStream, update, service)
+		close(hcExitChan)
+		return err
+	}
+
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	defer cc.Close()
+
+	tc := testpb.NewTestServiceClient(cc)
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "foo"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+	// Make some RPCs to make sure the connection is working.
+	if err := verifyResultWithDelay(func() (bool, error) {
+		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		return true, nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case <-hcExitChan:
+		t.Fatal("Health check function has exited, which is not expected.")
+	default:
+	}
+
+	// trigger addrConn teardown
+	cc.Close()
+
+	select {
+	case <-hcExitChan:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Health check function has not exited after 5s.")
+	}
+}
+
+// This test verifies the logic in createTransport that runs after the health check function
+// returns: it closes the skipReset channel (since it has not been closed inside the health
+// check func) to unblock the onGoAway/onClose goroutine.
+func TestHealthCheckWithoutReportHealthCalledAddrConnShutDown(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServerWithWatchFunc(func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthpb.Health_WatchServer) error {
+		if in.Service != "delay" {
+			return status.Error(codes.FailedPrecondition,
+				"this special Watch function only handles requests with service name \"delay\"")
+		}
+		// Do nothing, to mock a delay of the health check response from the server side.
+		// This case helps cover the condition where reportHealth is not called inside
+		// HealthCheckFunc before the func returns.
+		select {
+		case <-stream.Context().Done():
+		case <-time.After(5 * time.Second):
+		}
+		return nil
+	})
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+	ts.SetServingStatus("delay", healthpb.HealthCheckResponse_SERVING)
+
+	hcEnterChan := make(chan struct{})
+	hcExitChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		close(hcEnterChan)
+		err := testHealthCheckFunc(ctx, newStream, update, service)
+		close(hcExitChan)
+		return err
+	}
+
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	defer cc.Close()
+
+	// The serviceName "delay" is specially handled at the server side, where the response will not
+	// be sent back to the client immediately upon receiving the request (the client should receive
+	// no response until the test ends).
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "delay"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+	select {
+	case <-hcExitChan:
+		t.Fatal("Health check function has exited, which is not expected.")
+	default:
+	}
+
+	select {
+	case <-hcEnterChan:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Health check function has not been invoked after 5s.")
+	}
+	// Trigger teardown of the ac; the ac ends up in SHUTDOWN state.
+	r.NewAddress([]resolver.Address{})
+
+	// The health check func should exit without calling the reportHealth func, as the server
+	// hasn't sent any response.
+	select {
+	case <-hcExitChan:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Health check function has not exited after 5s.")
+	}
+	// The deferred leakcheck will check whether there are leaked goroutines, which indicates
+	// whether we closed the skipReset channel to unblock the onGoAway/onClose goroutine.
+}
+
+// This test verifies the logic in createTransport that runs after the health check function
+// returns: it closes the allowedToReset channel (since it has not been closed inside the health
+// check func) to unblock the onGoAway/onClose goroutine.
+func TestHealthCheckWithoutReportHealthCalled(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServerWithWatchFunc(func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthpb.Health_WatchServer) error {
+		if in.Service != "delay" {
+			return status.Error(codes.FailedPrecondition,
+				"this special Watch function only handles requests with service name \"delay\"")
+		}
+		// Do nothing, to mock a delay of the health check response from the server side.
+		// This case helps cover the condition where reportHealth is not called inside
+		// HealthCheckFunc before the func returns.
+		select {
+		case <-stream.Context().Done():
+		case <-time.After(5 * time.Second):
+		}
+		return nil
+	})
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+	ts.SetServingStatus("delay", healthpb.HealthCheckResponse_SERVING)
+
+	hcEnterChan := make(chan struct{})
+	hcExitChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		close(hcEnterChan)
+		err := testHealthCheckFunc(ctx, newStream, update, service)
+		close(hcExitChan)
+		return err
+	}
+
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	defer cc.Close()
+
+	// The serviceName "delay" is specially handled at the server side, where the response will not
+	// be sent back to the client immediately upon receiving the request (the client should receive
+	// no response until the test ends).
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "delay"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+	select {
+	case <-hcExitChan:
+		t.Fatal("Health check function has exited, which is not expected.")
+	default:
+	}
+
+	select {
+	case <-hcEnterChan:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Health check function has not been invoked after 5s.")
+	}
+	// Trigger the transport being closed.
+	s.Stop()
+
+	// The health check func should exit without calling the reportHealth func, as the server
+	// hasn't sent any response.
+	select {
+	case <-hcExitChan:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Health check function has not exited after 5s.")
+	}
+	// The deferred leakcheck will check whether there are leaked goroutines, which indicates
+	// whether we closed the allowedToReset channel to unblock the onGoAway/onClose goroutine.
+}
+
+func testHealthCheckDisableWithDialOption(t *testing.T, addr string) {
+	hcEnterChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		close(hcEnterChan)
+		return nil
+	}
+
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"), grpc.WithDisableHealthCheck())
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	tc := testpb.NewTestServiceClient(cc)
+	defer cc.Close()
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "foo"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: addr}})
+
+	// Send some RPCs to make sure the transport has been created and is ready for use.
+	if err := verifyResultWithDelay(func() (bool, error) {
+		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		return true, nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case <-hcEnterChan:
+		t.Fatal("Health check function has started, which is not expected.")
+	default:
+	}
+}
+
+func testHealthCheckDisableWithBalancer(t *testing.T, addr string) {
+	hcEnterChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		close(hcEnterChan)
+		return nil
+	}
+
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("pick_first"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	tc := testpb.NewTestServiceClient(cc)
+	defer cc.Close()
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "foo"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: addr}})
+
+	// Send some RPCs to make sure the transport has been created and is ready for use.
+	if err := verifyResultWithDelay(func() (bool, error) {
+		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		return true, nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case <-hcEnterChan:
+		t.Fatal("Health check function has started, which is not expected.")
+	default:
+	}
+}
+
+func testHealthCheckDisableWithServiceConfig(t *testing.T, addr string) {
+	hcEnterChan := make(chan struct{})
+	testHealthCheckFuncWrapper := func(ctx context.Context, newStream func() (interface{}, error), update func(bool), service string) error {
+		close(hcEnterChan)
+		return nil
+	}
+
+	replace := replaceHealthCheckFunc(testHealthCheckFuncWrapper)
+	defer replace()
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	tc := testpb.NewTestServiceClient(cc)
+	defer cc.Close()
+
+	r.NewAddress([]resolver.Address{{Addr: addr}})
+
+	// Send some RPCs to make sure the transport has been created and is ready for use.
+	if err := verifyResultWithDelay(func() (bool, error) {
+		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
+		}
+		return true, nil
+	}); err != nil {
+		t.Fatal(err)
+	}
+
+	select {
+	case <-hcEnterChan:
+		t.Fatal("Health check function has started, which is not expected.")
+	default:
+	}
+}
+
+func TestHealthCheckDisable(t *testing.T) {
+	defer leakcheck.Check(t)
+	// set up server side
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServer()
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
+
+	// test client side disabling configuration.
+	testHealthCheckDisableWithDialOption(t, lis.Addr().String())
+	testHealthCheckDisableWithBalancer(t, lis.Addr().String())
+	testHealthCheckDisableWithServiceConfig(t, lis.Addr().String())
+}
+
+func TestHealthCheckChannelzCountingCallSuccess(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServerWithWatchFunc(func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthpb.Health_WatchServer) error {
+		if in.Service != "channelzSuccess" {
+			return status.Error(codes.FailedPrecondition,
+				"this special Watch function only handles requests with service name \"channelzSuccess\"")
+		}
+		return status.Error(codes.OK, "fake success")
+	})
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	defer cc.Close()
+
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "channelzSuccess"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+	if err := verifyResultWithDelay(func() (bool, error) {
+		cm, _ := channelz.GetTopChannels(0)
+		if len(cm) == 0 {
+			return false, errors.New("channelz.GetTopChannels returned 0 top channels")
+		}
+		if len(cm[0].SubChans) == 0 {
+			return false, errors.New("there are 0 subchannels")
+		}
+		var id int64
+		for k := range cm[0].SubChans {
+			id = k
+			break
+		}
+		scm := channelz.GetSubChannel(id)
+		if scm == nil || scm.ChannelData == nil {
+			return false, errors.New("nil subchannel metric or nil subchannel metric ChannelData returned")
+		}
+		// Exponential backoff retry may result in more than one health check call.
+		if scm.ChannelData.CallsStarted > 0 && scm.ChannelData.CallsSucceeded > 0 && scm.ChannelData.CallsFailed == 0 {
+			return true, nil
+		}
+		return false, fmt.Errorf("got %d CallsStarted, %d CallsSucceeded, want >0, >0", scm.ChannelData.CallsStarted, scm.ChannelData.CallsSucceeded)
+	}); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestHealthCheckChannelzCountingCallFailure(t *testing.T) {
+	defer leakcheck.Check(t)
+	s := grpc.NewServer()
+	lis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("failed to listen due to err: %v", err)
+	}
+	ts := newTestHealthServerWithWatchFunc(func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthpb.Health_WatchServer) error {
+		if in.Service != "channelzFailure" {
+			return status.Error(codes.FailedPrecondition,
+				"this special Watch function only handles requests with service name \"channelzFailure\"")
+		}
+		return status.Error(codes.Internal, "fake failure")
+	})
+	healthpb.RegisterHealthServer(s, ts)
+	testpb.RegisterTestServiceServer(s, &testServer{})
+	go s.Serve(lis)
+	defer s.Stop()
+
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName("round_robin"))
+	if err != nil {
+		t.Fatalf("dial failed due to err: %v", err)
+	}
+	defer cc.Close()
+
+	r.NewServiceConfig(`{
+	"healthCheckConfig": {
+		"serviceName": "channelzFailure"
+	}
+}`)
+	r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+	if err := verifyResultWithDelay(func() (bool, error) {
+		cm, _ := channelz.GetTopChannels(0)
+		if len(cm) == 0 {
+			return false, errors.New("channelz.GetTopChannels returned 0 top channels")
+		}
+		if len(cm[0].SubChans) == 0 {
+			return false, errors.New("there are 0 subchannels")
+		}
+		var id int64
+		for k := range cm[0].SubChans {
+			id = k
+			break
+		}
+		scm := channelz.GetSubChannel(id)
+		if scm == nil || scm.ChannelData == nil {
+			return false, errors.New("nil subchannel metric or nil subchannel metric ChannelData returned")
+		}
+		// Exponential backoff retry may result in more than one health check call.
+		if scm.ChannelData.CallsStarted > 0 && scm.ChannelData.CallsFailed > 0 && scm.ChannelData.CallsSucceeded == 0 {
+			return true, nil
+		}
+		return false, fmt.Errorf("got %d CallsStarted, %d CallsFailed, want >0, >0", scm.ChannelData.CallsStarted, scm.ChannelData.CallsFailed)
+	}); err != nil {
+		t.Fatal(err)
+	}
+}
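+
+// Putting the pieces together, a client would enable health checking roughly as
+// follows (a sketch; the target and service name below are placeholders):
+//
+//	import _ "google.golang.org/grpc/health" // sets internal.HealthCheckFunc
+//
+//	cc, err := grpc.Dial(
+//		"my.server:443",                      // hypothetical target
+//		grpc.WithInsecure(),
+//		grpc.WithBalancerName("round_robin"), // the balancer must allow health checking
+//	)
+//	// ...plus a service config carrying {"healthCheckConfig": {"serviceName": "foo"}},
+//	// e.g. delivered by the resolver as in the tests above.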