diff --git a/benchmark/benchmain/main.go b/benchmark/benchmain/main.go index a62bb9300..5a9e478d2 100644 --- a/benchmark/benchmain/main.go +++ b/benchmark/benchmain/main.go @@ -101,6 +101,7 @@ var ( memProfile, cpuProfile string memProfileRate int enableCompressor []bool + enableChannelz []bool networkMode string benchmarkResultFile string networks = map[string]latency.Network{ @@ -283,15 +284,16 @@ var useBufconn = flag.Bool("bufconn", false, "Use in-memory connection instead o // Initiate main function to get settings of features. func init() { var ( - workloads, traceMode, compressorMode, readLatency string - readKbps, readMtu, readMaxConcurrentCalls intSliceType - readReqSizeBytes, readRespSizeBytes intSliceType + workloads, traceMode, compressorMode, readLatency, channelzOn string + readKbps, readMtu, readMaxConcurrentCalls intSliceType + readReqSizeBytes, readRespSizeBytes intSliceType ) flag.StringVar(&workloads, "workloads", workloadsAll, fmt.Sprintf("Workloads to execute - One of: %v", strings.Join(allWorkloads, ", "))) flag.StringVar(&traceMode, "trace", modeOff, fmt.Sprintf("Trace mode - One of: %v", strings.Join(allTraceModes, ", "))) flag.StringVar(&readLatency, "latency", "", "Simulated one-way network latency - may be a comma-separated list") + flag.StringVar(&channelzOn, "channelz", modeOff, "whether channelz should be turned on") flag.DurationVar(&benchtime, "benchtime", time.Second, "Configures the amount of time to run each benchmark") flag.Var(&readKbps, "kbps", "Simulated network throughput (in kbps) - may be a comma-separated list") flag.Var(&readMtu, "mtu", "Simulated network MTU (Maximum Transmission Unit) - may be a comma-separated list") @@ -327,6 +329,7 @@ func init() { } enableCompressor = setMode(compressorMode) enableTrace = setMode(traceMode) + enableChannelz = setMode(channelzOn) // Time input formats as (time + unit). 
readTimeFromInput(<c, readLatency) readIntFromIntSlice(&kbps, readKbps) @@ -400,10 +403,10 @@ func readTimeFromInput(values *[]time.Duration, replace string) { func main() { before() - featuresPos := make([]int, 8) + featuresPos := make([]int, 9) // 0:enableTracing 1:ltc 2:kbps 3:mtu 4:maxC 5:reqSize 6:respSize featuresNum := []int{len(enableTrace), len(ltc), len(kbps), len(mtu), - len(maxConcurrentCalls), len(reqSizeBytes), len(respSizeBytes), len(enableCompressor)} + len(maxConcurrentCalls), len(reqSizeBytes), len(respSizeBytes), len(enableCompressor), len(enableChannelz)} initalPos := make([]int, len(featuresPos)) s := stats.NewStats(10) s.SortLatency() @@ -444,9 +447,13 @@ func main() { ReqSizeBytes: reqSizeBytes[featuresPos[5]], RespSizeBytes: respSizeBytes[featuresPos[6]], EnableCompressor: enableCompressor[featuresPos[7]], + EnableChannelz: enableChannelz[featuresPos[8]], } grpc.EnableTracing = enableTrace[featuresPos[0]] + if enableChannelz[featuresPos[8]] { + grpc.RegisterChannelz() + } if runMode[0] { unaryBenchmark(startTimer, stopTimer, benchFeature, benchtime, s) s.SetBenchmarkResult("Unary", benchFeature, results.N, diff --git a/benchmark/benchmark16_test.go b/benchmark/benchmark16_test.go index fc33e80d0..a036b63cb 100644 --- a/benchmark/benchmark16_test.go +++ b/benchmark/benchmark16_test.go @@ -30,81 +30,81 @@ import ( func BenchmarkClientStreamc1(b *testing.B) { grpc.EnableTracing = true - runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false}) + runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false, false}) } func BenchmarkClientStreamc8(b *testing.B) { grpc.EnableTracing = true - runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false}) + runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false, false}) } func BenchmarkClientStreamc64(b *testing.B) { grpc.EnableTracing = true - runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false}) + runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false, false}) } func BenchmarkClientStreamc512(b *testing.B) { grpc.EnableTracing = true - runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false}) + runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false, false}) } func BenchmarkClientUnaryc1(b *testing.B) { grpc.EnableTracing = true - runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false}) + runStream(b, stats.Features{"", true, 0, 0, 0, 1, 1, 1, false, false}) } func BenchmarkClientUnaryc8(b *testing.B) { grpc.EnableTracing = true - runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false}) + runStream(b, stats.Features{"", true, 0, 0, 0, 8, 1, 1, false, false}) } func BenchmarkClientUnaryc64(b *testing.B) { grpc.EnableTracing = true - runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false}) + runStream(b, stats.Features{"", true, 0, 0, 0, 64, 1, 1, false, false}) } func BenchmarkClientUnaryc512(b *testing.B) { grpc.EnableTracing = true - runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false}) + runStream(b, stats.Features{"", true, 0, 0, 0, 512, 1, 1, false, false}) } func BenchmarkClientStreamNoTracec1(b *testing.B) { grpc.EnableTracing = false - runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false, false}) } func BenchmarkClientStreamNoTracec8(b *testing.B) { grpc.EnableTracing = false - runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false, false}) } func BenchmarkClientStreamNoTracec64(b 
*testing.B) { grpc.EnableTracing = false - runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false, false}) } func BenchmarkClientStreamNoTracec512(b *testing.B) { grpc.EnableTracing = false - runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false, false}) } func BenchmarkClientUnaryNoTracec1(b *testing.B) { grpc.EnableTracing = false - runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 1, 1, 1, false, false}) } func BenchmarkClientUnaryNoTracec8(b *testing.B) { grpc.EnableTracing = false - runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 8, 1, 1, false, false}) } func BenchmarkClientUnaryNoTracec64(b *testing.B) { grpc.EnableTracing = false - runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 64, 1, 1, false, false}) } func BenchmarkClientUnaryNoTracec512(b *testing.B) { grpc.EnableTracing = false - runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false}) - runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false, false}) + runStream(b, stats.Features{"", false, 0, 0, 0, 512, 1, 1, false, false}) } func TestMain(m *testing.M) { diff --git a/benchmark/stats/stats.go b/benchmark/stats/stats.go index 2f98861c9..041d2475c 100644 --- a/benchmark/stats/stats.go +++ b/benchmark/stats/stats.go @@ -39,6 +39,7 @@ type Features struct { ReqSizeBytes int RespSizeBytes int EnableCompressor bool + EnableChannelz bool } // String returns the textual output of the Features as string. @@ -48,6 +49,13 @@ func (f Features) String() string { f.Latency.String(), f.Kbps, f.Mtu, f.MaxConcurrentCalls, f.ReqSizeBytes, f.RespSizeBytes, f.EnableCompressor) } +// ConciseString returns the concise textual output of the Features as string, skipping +// setting with default value. +func (f Features) ConciseString() string { + noneEmptyPos := []bool{f.EnableTrace, f.Latency != 0, f.Kbps != 0, f.Mtu != 0, true, true, true, f.EnableCompressor, f.EnableChannelz} + return PartialPrintString(noneEmptyPos, f, false) +} + // PartialPrintString can print certain features with different format. func PartialPrintString(noneEmptyPos []bool, f Features, shared bool) string { s := "" @@ -63,7 +71,7 @@ func PartialPrintString(noneEmptyPos []bool, f Features, shared bool) string { linker = "_" } if noneEmptyPos[0] { - s += fmt.Sprintf("%sTrace%s%t%s", prefix, linker, f.EnableCompressor, suffix) + s += fmt.Sprintf("%sTrace%s%t%s", prefix, linker, f.EnableTrace, suffix) } if shared && f.NetworkMode != "" { s += fmt.Sprintf("Network: %s \n", f.NetworkMode) @@ -92,6 +100,9 @@ func PartialPrintString(noneEmptyPos []bool, f Features, shared bool) string { if noneEmptyPos[7] { s += fmt.Sprintf("%sCompressor%s%t%s", prefix, linker, f.EnableCompressor, suffix) } + if noneEmptyPos[8] { + s += fmt.Sprintf("%sChannelz%s%t%s", prefix, linker, f.EnableChannelz, suffix) + } return s } diff --git a/channelz/types.go b/channelz/types.go index b84abfe86..153d75340 100644 --- a/channelz/types.go +++ b/channelz/types.go @@ -243,10 +243,13 @@ type SocketMetric struct { type SocketInternalMetric struct { // The number of streams that have been started. 
StreamsStarted int64 - // The number of streams that have ended successfully with the EoS bit set for - // both end points. + // The number of streams that have ended successfully: + // On client side, receiving frame with eos bit set. + // On server side, sending frame with eos bit set. StreamsSucceeded int64 - // The number of incoming streams that have a completed with a non-OK status. + // The number of streams that have ended unsuccessfully: + // On client side, termination without receiving frame with eos bit set. + // On server side, termination without sending frame with eos bit set. StreamsFailed int64 // The number of messages successfully sent on this socket. MessagesSent int64 diff --git a/clientconn.go b/clientconn.go index b3b60996e..70e5b8025 100644 --- a/clientconn.go +++ b/clientconn.go @@ -675,7 +675,12 @@ type ClientConn struct { curAddresses []resolver.Address balancerWrapper *ccBalancerWrapper - channelzID int64 // channelz unique identification number + channelzID int64 // channelz unique identification number + czmu sync.RWMutex + callsStarted int64 + callsSucceeded int64 + callsFailed int64 + lastCallStartedTime time.Time } // WaitForStateChange waits until the connectivity.State of ClientConn changes from sourceState or @@ -863,7 +868,37 @@ func (cc *ClientConn) removeAddrConn(ac *addrConn, err error) { // ChannelzMetric returns ChannelInternalMetric of current ClientConn. // This is an EXPERIMENTAL API. func (cc *ClientConn) ChannelzMetric() *channelz.ChannelInternalMetric { - return &channelz.ChannelInternalMetric{} + state := cc.GetState() + cc.czmu.RLock() + defer cc.czmu.RUnlock() + return &channelz.ChannelInternalMetric{ + State: state, + Target: cc.target, + CallsStarted: cc.callsStarted, + CallsSucceeded: cc.callsSucceeded, + CallsFailed: cc.callsFailed, + LastCallStartedTimestamp: cc.lastCallStartedTime, + } +} + +func (cc *ClientConn) incrCallsStarted() { + cc.czmu.Lock() + cc.callsStarted++ + // TODO(yuxuanli): will make this a time.Time pointer improve performance? + cc.lastCallStartedTime = time.Now() + cc.czmu.Unlock() +} + +func (cc *ClientConn) incrCallsSucceeded() { + cc.czmu.Lock() + cc.callsSucceeded++ + cc.czmu.Unlock() +} + +func (cc *ClientConn) incrCallsFailed() { + cc.czmu.Lock() + cc.callsFailed++ + cc.czmu.Unlock() } // connect starts to creating transport and also starts the transport monitor @@ -1013,13 +1048,16 @@ func (cc *ClientConn) Close() error { bWrapper := cc.balancerWrapper cc.balancerWrapper = nil cc.mu.Unlock() + cc.blockingpicker.close() + if rWrapper != nil { rWrapper.close() } if bWrapper != nil { bWrapper.close() } + for ac := range conns { ac.tearDown(ErrClientConnClosing) } @@ -1060,7 +1098,12 @@ type addrConn struct { // negotiations must complete. 
connectDeadline time.Time - channelzID int64 // channelz unique identification number + channelzID int64 // channelz unique identification number + czmu sync.RWMutex + callsStarted int64 + callsSucceeded int64 + callsFailed int64 + lastCallStartedTime time.Time } // adjustParams updates parameters used to create transports upon @@ -1467,7 +1510,39 @@ func (ac *addrConn) getState() connectivity.State { } func (ac *addrConn) ChannelzMetric() *channelz.ChannelInternalMetric { - return &channelz.ChannelInternalMetric{} + ac.mu.Lock() + addr := ac.curAddr.Addr + ac.mu.Unlock() + state := ac.getState() + ac.czmu.RLock() + defer ac.czmu.RUnlock() + return &channelz.ChannelInternalMetric{ + State: state, + Target: addr, + CallsStarted: ac.callsStarted, + CallsSucceeded: ac.callsSucceeded, + CallsFailed: ac.callsFailed, + LastCallStartedTimestamp: ac.lastCallStartedTime, + } +} + +func (ac *addrConn) incrCallsStarted() { + ac.czmu.Lock() + ac.callsStarted++ + ac.lastCallStartedTime = time.Now() + ac.czmu.Unlock() +} + +func (ac *addrConn) incrCallsSucceeded() { + ac.czmu.Lock() + ac.callsSucceeded++ + ac.czmu.Unlock() +} + +func (ac *addrConn) incrCallsFailed() { + ac.czmu.Lock() + ac.callsFailed++ + ac.czmu.Unlock() } // ErrClientConnTimeout indicates that the ClientConn cannot establish the diff --git a/picker_wrapper.go b/picker_wrapper.go index 4d0082593..3a466861e 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -19,10 +19,12 @@ package grpc import ( + "io" "sync" "golang.org/x/net/context" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/channelz" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/status" @@ -74,6 +76,23 @@ func (bp *pickerWrapper) updatePicker(p balancer.Picker) { bp.mu.Unlock() } +func doneChannelzWrapper(acw *acBalancerWrapper, done func(balancer.DoneInfo)) func(balancer.DoneInfo) { + acw.mu.Lock() + ac := acw.ac + acw.mu.Unlock() + ac.incrCallsStarted() + return func(b balancer.DoneInfo) { + if b.Err != nil && b.Err != io.EOF { + ac.incrCallsFailed() + } else { + ac.incrCallsSucceeded() + } + if done != nil { + done(b) + } + } +} + // pick returns the transport that will be used for the RPC. // It may block in the following cases: // - there's no picker @@ -137,6 +156,9 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer. continue } if t, ok := acw.getAddrConn().getReadyTransport(); ok { + if channelz.IsOn() { + return t, doneChannelzWrapper(acw, done), nil + } return t, done, nil } grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick") diff --git a/server.go b/server.go index b5363572a..7ccd5662d 100644 --- a/server.go +++ b/server.go @@ -106,7 +106,12 @@ type Server struct { channelzRemoveOnce sync.Once serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop - channelzID int64 // channelz unique identification number + channelzID int64 // channelz unique identification number + czmu sync.RWMutex + callsStarted int64 + callsFailed int64 + callsSucceeded int64 + lastCallStartedTime time.Time } type options struct { @@ -473,7 +478,9 @@ type listenSocket struct { } func (l *listenSocket) ChannelzMetric() *channelz.SocketInternalMetric { - return &channelz.SocketInternalMetric{} + return &channelz.SocketInternalMetric{ + LocalAddr: l.Listener.Addr(), + } } func (l *listenSocket) Close() error { @@ -508,12 +515,6 @@ func (s *Server) Serve(lis net.Listener) error { // Stop or GracefulStop called; block until done and return nil. 
case <-s.quit: <-s.done - - s.channelzRemoveOnce.Do(func() { - if channelz.IsOn() { - channelz.RemoveEntry(s.channelzID) - } - }) default: } }() @@ -794,7 +795,33 @@ func (s *Server) removeConn(c io.Closer) { // ChannelzMetric returns ServerInternalMetric of current server. // This is an EXPERIMENTAL API. func (s *Server) ChannelzMetric() *channelz.ServerInternalMetric { - return &channelz.ServerInternalMetric{} + s.czmu.RLock() + defer s.czmu.RUnlock() + return &channelz.ServerInternalMetric{ + CallsStarted: s.callsStarted, + CallsSucceeded: s.callsSucceeded, + CallsFailed: s.callsFailed, + LastCallStartedTimestamp: s.lastCallStartedTime, + } +} + +func (s *Server) incrCallsStarted() { + s.czmu.Lock() + s.callsStarted++ + s.lastCallStartedTime = time.Now() + s.czmu.Unlock() +} + +func (s *Server) incrCallsSucceeded() { + s.czmu.Lock() + s.callsSucceeded++ + s.czmu.Unlock() +} + +func (s *Server) incrCallsFailed() { + s.czmu.Lock() + s.callsFailed++ + s.czmu.Unlock() } func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { @@ -821,6 +848,16 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { + if channelz.IsOn() { + s.incrCallsStarted() + defer func() { + if err != nil && err != io.EOF { + s.incrCallsFailed() + } else { + s.incrCallsSucceeded() + } + }() + } sh := s.opts.statsHandler if sh != nil { beginTime := time.Now() @@ -915,6 +952,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return err } + if channelz.IsOn() { + t.IncrMsgRecv() + } if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil { if e := t.WriteStatus(stream, st); e != nil { grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e) @@ -1014,6 +1054,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return err } + if channelz.IsOn() { + t.IncrMsgSent() + } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } @@ -1024,6 +1067,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
} func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { + if channelz.IsOn() { + s.incrCallsStarted() + defer func() { + if err != nil && err != io.EOF { + s.incrCallsFailed() + } else { + s.incrCallsSucceeded() + } + }() + } sh := s.opts.statsHandler if sh != nil { beginTime := time.Now() diff --git a/stream.go b/stream.go index 0c33a1222..82921a15a 100644 --- a/stream.go +++ b/stream.go @@ -27,6 +27,7 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/channelz" "google.golang.org/grpc/codes" "google.golang.org/grpc/encoding" "google.golang.org/grpc/metadata" @@ -121,6 +122,14 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { + if channelz.IsOn() { + cc.incrCallsStarted() + defer func() { + if err != nil { + cc.incrCallsFailed() + } + }() + } c := defaultCallInfo() mc := cc.GetMethodConfig(method) if mc.WaitForReady != nil { @@ -272,6 +281,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cs := &clientStream{ opts: opts, c: c, + cc: cc, desc: desc, codec: c.codec, cp: cp, @@ -313,6 +323,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth type clientStream struct { opts []CallOption c *callInfo + cc *ClientConn desc *StreamDesc codec baseCodec @@ -401,6 +412,13 @@ func (cs *clientStream) finish(err error) { } cs.finished = true cs.mu.Unlock() + if channelz.IsOn() { + if err != nil { + cs.cc.incrCallsFailed() + } else { + cs.cc.incrCallsSucceeded() + } + } // TODO(retry): commit current attempt if necessary. cs.attempt.finish(err) for _, o := range cs.opts { @@ -470,6 +488,9 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) { outPayload.SentTime = time.Now() a.statsHandler.HandleRPC(a.ctx, outPayload) } + if channelz.IsOn() { + a.t.IncrMsgSent() + } return nil } return io.EOF @@ -525,6 +546,9 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) { if inPayload != nil { a.statsHandler.HandleRPC(a.ctx, inPayload) } + if channelz.IsOn() { + a.t.IncrMsgRecv() + } if cs.desc.ServerStreams { // Subsequent messages should be received by subsequent RecvMsg calls. 
return nil @@ -668,6 +692,9 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { st, _ := status.FromError(toRPCErr(err)) ss.t.WriteStatus(ss.s, st) } + if channelz.IsOn() && err == nil { + ss.t.IncrMsgSent() + } }() var outPayload *stats.OutPayload if ss.statsHandler != nil { @@ -708,6 +735,9 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { st, _ := status.FromError(toRPCErr(err)) ss.t.WriteStatus(ss.s, st) } + if channelz.IsOn() && err == nil { + ss.t.IncrMsgRecv() + } }() var inPayload *stats.InPayload if ss.statsHandler != nil { diff --git a/test/channelz_test.go b/test/channelz_test.go index 0492e3138..82f1a8f09 100644 --- a/test/channelz_test.go +++ b/test/channelz_test.go @@ -21,14 +21,19 @@ package test import ( "fmt" "net" + "sync" "testing" "time" + "golang.org/x/net/context" + "golang.org/x/net/http2" "google.golang.org/grpc" - "google.golang.org/grpc/channelz" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/test/leakcheck" ) @@ -408,3 +413,820 @@ func TestCZRecusivelyDeletionOfEntry(t *testing.T) { t.Fatalf("There should be no TopChannel entry") } } + +func TestCZChannelMetrics(t *testing.T) { + defer leakcheck.Check(t) + channelz.NewChannelzStorage() + e := tcpClearRREnv + num := 3 // number of backends + te := newTest(t, e) + te.maxClientSendMsgSize = newInt(8) + var svrAddrs []resolver.Address + te.startServers(&testServer{security: e.security}, num) + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + for _, a := range te.srvAddrs { + svrAddrs = append(svrAddrs, resolver.Address{Addr: a}) + } + r.InitialAddrs(svrAddrs) + te.resolverScheme = r.Scheme() + cc := te.clientConn() + defer te.tearDown() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + + const smallSize = 1 + const largeSize = 8 + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: int32(smallSize), + Payload: largePayload, + } + + if _, err := tc.UnaryCall(context.Background(), req); err == nil || status.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + defer stream.CloseSend() + // Here, we just wait for all sockets to be up. In the future, if we implement + // IDLE, we may need to make several rpc calls to create the sockets. 
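+ // Expected counts at this point: EmptyCall succeeded, the UnaryCall exceeded
+ // maxClientSendMsgSize (8) and failed with ResourceExhausted, and the FullDuplexCall
+ // above is still open, so the checks below expect CallsStarted=3, CallsSucceeded=1, CallsFailed=1.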
+ if err := verifyResultWithDelay(func() (bool, error) { + tcs, _ := channelz.GetTopChannels(0) + if len(tcs) != 1 { + return false, fmt.Errorf("There should only be one top channel, not %d", len(tcs)) + } + if len(tcs[0].SubChans) != num { + return false, fmt.Errorf("There should be %d subchannel not %d", num, len(tcs[0].SubChans)) + } + var cst, csu, cf int64 + for k := range tcs[0].SubChans { + sc := channelz.GetSubChannel(k) + if sc == nil { + return false, fmt.Errorf("got subchannel") + } + cst += sc.ChannelData.CallsStarted + csu += sc.ChannelData.CallsSucceeded + cf += sc.ChannelData.CallsFailed + } + if cst != 3 { + return false, fmt.Errorf("There should be 3 CallsStarted not %d", cst) + } + if csu != 1 { + return false, fmt.Errorf("There should be 1 CallsSucceeded not %d", csu) + } + if cf != 1 { + return false, fmt.Errorf("There should be 1 CallsFailed not %d", cf) + } + if tcs[0].ChannelData.CallsStarted != 3 { + return false, fmt.Errorf("There should be 3 CallsStarted not %d", tcs[0].ChannelData.CallsStarted) + } + if tcs[0].ChannelData.CallsSucceeded != 1 { + return false, fmt.Errorf("There should be 1 CallsSucceeded not %d", tcs[0].ChannelData.CallsSucceeded) + } + if tcs[0].ChannelData.CallsFailed != 1 { + return false, fmt.Errorf("There should be 1 CallsFailed not %d", tcs[0].ChannelData.CallsFailed) + } + return true, nil + }); err != nil { + t.Fatal(err) + } +} + +func TestCZServerMetrics(t *testing.T) { + defer leakcheck.Check(t) + channelz.NewChannelzStorage() + e := tcpClearRREnv + te := newTest(t, e) + te.maxServerReceiveMsgSize = newInt(8) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + + const smallSize = 1 + const largeSize = 8 + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: int32(smallSize), + Payload: largePayload, + } + if _, err := tc.UnaryCall(context.Background(), req); err == nil || status.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } + + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + defer stream.CloseSend() + + if err := verifyResultWithDelay(func() (bool, error) { + ss, _ := channelz.GetServers(0) + if len(ss) != 1 { + return false, fmt.Errorf("There should only be one server, not %d", len(ss)) + } + if ss[0].ServerData.CallsStarted != 3 { + return false, fmt.Errorf("There should be 3 CallsStarted not %d", ss[0].ServerData.CallsStarted) + } + if ss[0].ServerData.CallsSucceeded != 1 { + return false, fmt.Errorf("There should be 1 CallsSucceeded not %d", ss[0].ServerData.CallsSucceeded) + } + if ss[0].ServerData.CallsFailed != 1 { + return false, fmt.Errorf("There should be 1 CallsFailed not %d", ss[0].ServerData.CallsFailed) + } + return true, nil + }); err != nil { + t.Fatal(err) + } +} + +type testServiceClientWrapper struct { + testpb.TestServiceClient + mu sync.RWMutex + streamsCreated int +} + +func (t *testServiceClientWrapper) getCurrentStreamID() uint32 { + t.mu.RLock() + defer t.mu.RUnlock() + return uint32(2*t.streamsCreated - 1) +} + +func (t 
*testServiceClientWrapper) EmptyCall(ctx context.Context, in *testpb.Empty, opts ...grpc.CallOption) (*testpb.Empty, error) { + t.mu.Lock() + defer t.mu.Unlock() + t.streamsCreated++ + return t.TestServiceClient.EmptyCall(ctx, in, opts...) +} + +func (t *testServiceClientWrapper) UnaryCall(ctx context.Context, in *testpb.SimpleRequest, opts ...grpc.CallOption) (*testpb.SimpleResponse, error) { + t.mu.Lock() + defer t.mu.Unlock() + t.streamsCreated++ + return t.TestServiceClient.UnaryCall(ctx, in, opts...) +} + +func (t *testServiceClientWrapper) StreamingOutputCall(ctx context.Context, in *testpb.StreamingOutputCallRequest, opts ...grpc.CallOption) (testpb.TestService_StreamingOutputCallClient, error) { + t.mu.Lock() + defer t.mu.Unlock() + t.streamsCreated++ + return t.TestServiceClient.StreamingOutputCall(ctx, in, opts...) +} + +func (t *testServiceClientWrapper) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (testpb.TestService_StreamingInputCallClient, error) { + t.mu.Lock() + defer t.mu.Unlock() + t.streamsCreated++ + return t.TestServiceClient.StreamingInputCall(ctx, opts...) +} + +func (t *testServiceClientWrapper) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (testpb.TestService_FullDuplexCallClient, error) { + t.mu.Lock() + defer t.mu.Unlock() + t.streamsCreated++ + return t.TestServiceClient.FullDuplexCall(ctx, opts...) +} + +func (t *testServiceClientWrapper) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (testpb.TestService_HalfDuplexCallClient, error) { + t.mu.Lock() + defer t.mu.Unlock() + t.streamsCreated++ + return t.TestServiceClient.HalfDuplexCall(ctx, opts...) +} + +func doSuccessfulUnaryCall(tc testpb.TestServiceClient, t *testing.T) { + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } +} + +func doStreamingInputCallWithLargePayload(tc testpb.TestServiceClient, t *testing.T) { + s, err := tc.StreamingInputCall(context.Background()) + if err != nil { + t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want ", err) + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 10000) + if err != nil { + t.Fatal(err) + } + s.Send(&testpb.StreamingInputCallRequest{Payload: payload}) +} + +func doServerSideFailedUnaryCall(tc testpb.TestServiceClient, t *testing.T) { + const smallSize = 1 + const largeSize = 2000 + + largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: int32(smallSize), + Payload: largePayload, + } + if _, err := tc.UnaryCall(context.Background(), req); err == nil || status.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } +} + +// This func is to be used to test server side counting of streams succeeded. +// It cannot be used for client side counting due to race between receiving +// server trailer (streamsSucceeded++) and CloseStream on error (streamsFailed++) +// on client side. 
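+// In the tests that use it, the requested 2000-byte response exceeds the client's
+// maximum receive message size, so the server finishes the stream normally while the
+// client fails the call locally with ResourceExhausted.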
+func doClientSideFailedUnaryCall(tc testpb.TestServiceClient, t *testing.T) { + const smallSize = 1 + const largeSize = 2000 + + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseSize: int32(largeSize), + Payload: smallPayload, + } + if _, err := tc.UnaryCall(context.Background(), req); err == nil || status.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted) + } +} + +func doClientSideInitiatedFailedStream(tc testpb.TestServiceClient, t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want ", err) + } + + const smallSize = 1 + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseParameters: []*testpb.ResponseParameters{ + {Size: smallSize}, + }, + Payload: smallPayload, + } + + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } + // By canceling the call, the client will send rst_stream to end the call, and + // the stream will failed as a result. + cancel() +} + +// This func is to be used to test client side counting of failed streams. +func doServerSideInitiatedFailedStreamWithRSTStream(tc testpb.TestServiceClient, t *testing.T, l *listenerWrapper) { + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want ", err) + } + + const smallSize = 1 + smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize) + if err != nil { + t.Fatal(err) + } + + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE, + ResponseParameters: []*testpb.ResponseParameters{ + {Size: smallSize}, + }, + Payload: smallPayload, + } + + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } + + rcw := l.getLastConn() + + if rcw != nil { + rcw.writeRSTStream(tc.(*testServiceClientWrapper).getCurrentStreamID(), http2.ErrCodeCancel) + } + if _, err := stream.Recv(); err == nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } +} + +// this func is to be used to test client side counting of failed streams. +func doServerSideInitiatedFailedStreamWithGoAway(tc testpb.TestServiceClient, t *testing.T, l *listenerWrapper) { + // This call is just to keep the transport from shutting down (socket will be deleted + // in this case, and we will not be able to get metrics). 
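+ // A second stream is started below, and the raw server connection then writes a GOAWAY
+ // whose last-stream-id is the first stream's ID, so only the in-flight second stream is
+ // terminated on the client and counted as a failed stream.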
+ s, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want ", err) + } + if err := s.Send(&testpb.StreamingOutputCallRequest{ResponseParameters: []*testpb.ResponseParameters{ + { + Size: 1, + }, + }}); err != nil { + t.Fatalf("s.Send() failed with error: %v", err) + } + if _, err := s.Recv(); err != nil { + t.Fatalf("s.Recv() failed with error: %v", err) + } + + s, err = tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want ", err) + } + if err := s.Send(&testpb.StreamingOutputCallRequest{ResponseParameters: []*testpb.ResponseParameters{ + { + Size: 1, + }, + }}); err != nil { + t.Fatalf("s.Send() failed with error: %v", err) + } + if _, err := s.Recv(); err != nil { + t.Fatalf("s.Recv() failed with error: %v", err) + } + + rcw := l.getLastConn() + if rcw != nil { + rcw.writeGoAway(tc.(*testServiceClientWrapper).getCurrentStreamID()-2, http2.ErrCodeCancel, []byte{}) + } + if _, err := s.Recv(); err == nil { + t.Fatalf("%v.Recv() = %v, want ", s, err) + } +} + +// this func is to be used to test client side counting of failed streams. +func doServerSideInitiatedFailedStreamWithClientBreakFlowControl(tc testpb.TestServiceClient, t *testing.T, dw *dialerWrapper) { + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want ", err) + } + // sleep here to make sure header frame being sent before the the data frame we write directly below. + time.Sleep(10 * time.Millisecond) + payload := make([]byte, 65537, 65537) + dw.getRawConnWrapper().writeRawFrame(http2.FrameData, 0, tc.(*testServiceClientWrapper).getCurrentStreamID(), payload) + if _, err := stream.Recv(); err == nil || status.Code(err) != codes.ResourceExhausted { + t.Fatalf("%v.Recv() = %v, want error code: %v", stream, err, codes.ResourceExhausted) + } +} + +func doIdleCallToInvokeKeepAlive(tc testpb.TestServiceClient, t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + _, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want ", err) + } + // 2500ms allow for 2 keepalives (1000ms per round trip) + time.Sleep(2500 * time.Millisecond) + cancel() +} + +func TestCZClientSocketMetricsStreamsAndMessagesCount(t *testing.T) { + defer leakcheck.Check(t) + channelz.NewChannelzStorage() + e := tcpClearRREnv + te := newTest(t, e) + te.maxServerReceiveMsgSize = newInt(20) + te.maxClientReceiveMsgSize = newInt(20) + rcw := te.startServerWithConnControl(&testServer{security: e.security}) + defer te.tearDown() + cc := te.clientConn() + tc := &testServiceClientWrapper{TestServiceClient: testpb.NewTestServiceClient(cc)} + + doSuccessfulUnaryCall(tc, t) + var scID, skID int64 + if err := verifyResultWithDelay(func() (bool, error) { + tchan, _ := channelz.GetTopChannels(0) + if len(tchan) != 1 { + return false, fmt.Errorf("There should only be one top channel, not %d", len(tchan)) + } + if len(tchan[0].SubChans) != 1 { + return false, fmt.Errorf("There should only be one subchannel under top channel %d, not %d", tchan[0].ID, len(tchan[0].SubChans)) + } + + for scID = range tchan[0].SubChans { + break + } + sc := channelz.GetSubChannel(scID) + if sc == nil { + return false, fmt.Errorf("There should only be one socket under subchannel %d, not 0", scID) + } + if len(sc.Sockets) != 1 { + return false, fmt.Errorf("There should only be one socket under subchannel %d, not %d", sc.ID, 
len(sc.Sockets)) + } + for skID = range sc.Sockets { + break + } + skt := channelz.GetSocket(skID) + sktData := skt.SocketData + if sktData.StreamsStarted != 1 || sktData.StreamsSucceeded != 1 || sktData.MessagesSent != 1 || sktData.MessagesReceived != 1 { + return false, fmt.Errorf("channelz.GetSocket(%d), want (StreamsStarted, StreamsSucceeded, MessagesSent, MessagesReceived) = (1, 1, 1, 1), got (%d, %d, %d, %d)", skt.ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } + + doServerSideFailedUnaryCall(tc, t) + if err := verifyResultWithDelay(func() (bool, error) { + skt := channelz.GetSocket(skID) + sktData := skt.SocketData + if sktData.StreamsStarted != 2 || sktData.StreamsSucceeded != 2 || sktData.MessagesSent != 2 || sktData.MessagesReceived != 1 { + return false, fmt.Errorf("channelz.GetSocket(%d), want (StreamsStarted, StreamsSucceeded, MessagesSent, MessagesReceived) = (2, 2, 2, 1), got (%d, %d, %d, %d)", skt.ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } + + doClientSideInitiatedFailedStream(tc, t) + if err := verifyResultWithDelay(func() (bool, error) { + skt := channelz.GetSocket(skID) + sktData := skt.SocketData + if sktData.StreamsStarted != 3 || sktData.StreamsSucceeded != 2 || sktData.StreamsFailed != 1 || sktData.MessagesSent != 3 || sktData.MessagesReceived != 2 { + return false, fmt.Errorf("channelz.GetSocket(%d), want (StreamsStarted, StreamsSucceeded, StreamsFailed, MessagesSent, MessagesReceived) = (3, 2, 1, 3, 2), got (%d, %d, %d, %d, %d)", skt.ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } + + doServerSideInitiatedFailedStreamWithRSTStream(tc, t, rcw) + if err := verifyResultWithDelay(func() (bool, error) { + skt := channelz.GetSocket(skID) + sktData := skt.SocketData + if sktData.StreamsStarted != 4 || sktData.StreamsSucceeded != 2 || sktData.StreamsFailed != 2 || sktData.MessagesSent != 4 || sktData.MessagesReceived != 3 { + return false, fmt.Errorf("channelz.GetSocket(%d), want (StreamsStarted, StreamsSucceeded, StreamsFailed, MessagesSent, MessagesReceived) = (4, 2, 2, 4, 3), got (%d, %d, %d, %d, %d)", skt.ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } + + doServerSideInitiatedFailedStreamWithGoAway(tc, t, rcw) + if err := verifyResultWithDelay(func() (bool, error) { + skt := channelz.GetSocket(skID) + sktData := skt.SocketData + if sktData.StreamsStarted != 6 || sktData.StreamsSucceeded != 2 || sktData.StreamsFailed != 3 || sktData.MessagesSent != 6 || sktData.MessagesReceived != 5 { + return false, fmt.Errorf("channelz.GetSocket(%d), want (StreamsStarted, StreamsSucceeded, StreamsFailed, MessagesSent, MessagesReceived) = (6, 2, 3, 6, 5), got (%d, %d, %d, %d, %d)", skt.ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } +} + +// This test is to complete TestCZClientSocketMetricsStreamsAndMessagesCount and +// TestCZServerSocketMetricsStreamsAndMessagesCount by adding the test case of +// server sending RST_STREAM to client due to 
client side flow control violation. +// It is separated from other cases due to setup incompatibly, i.e. max receive +// size violation will mask flow control violation. +func TestCZClientAndServerSocketMetricsStreamsCountFlowControlRSTStream(t *testing.T) { + defer leakcheck.Check(t) + channelz.NewChannelzStorage() + e := tcpClearRREnv + te := newTest(t, e) + te.serverInitialWindowSize = 65536 + // Avoid overflowing connection level flow control window, which will lead to + // transport being closed. + te.serverInitialConnWindowSize = 65536 * 2 + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + cc, dw := te.clientConnWithConnControl() + tc := &testServiceClientWrapper{TestServiceClient: testpb.NewTestServiceClient(cc)} + + doServerSideInitiatedFailedStreamWithClientBreakFlowControl(tc, t, dw) + if err := verifyResultWithDelay(func() (bool, error) { + tchan, _ := channelz.GetTopChannels(0) + if len(tchan) != 1 { + return false, fmt.Errorf("There should only be one top channel, not %d", len(tchan)) + } + if len(tchan[0].SubChans) != 1 { + return false, fmt.Errorf("There should only be one subchannel under top channel %d, not %d", tchan[0].ID, len(tchan[0].SubChans)) + } + var id int64 + for id = range tchan[0].SubChans { + break + } + sc := channelz.GetSubChannel(id) + if sc == nil { + return false, fmt.Errorf("There should only be one socket under subchannel %d, not 0", id) + } + if len(sc.Sockets) != 1 { + return false, fmt.Errorf("There should only be one socket under subchannel %d, not %d", sc.ID, len(sc.Sockets)) + } + for id = range sc.Sockets { + break + } + skt := channelz.GetSocket(id) + sktData := skt.SocketData + if sktData.StreamsStarted != 1 || sktData.StreamsSucceeded != 0 || sktData.StreamsFailed != 1 { + return false, fmt.Errorf("channelz.GetSocket(%d), want (StreamsStarted, StreamsSucceeded, StreamsFailed) = (1, 0, 1), got (%d, %d, %d)", skt.ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed) + } + ss, _ := channelz.GetServers(0) + if len(ss) != 1 { + return false, fmt.Errorf("There should only be one server, not %d", len(ss)) + } + + ns, _ := channelz.GetServerSockets(ss[0].ID, 0) + if len(ns) != 1 { + return false, fmt.Errorf("There should be one server normal socket, not %d", len(ns)) + } + sktData = ns[0].SocketData + if sktData.StreamsStarted != 1 || sktData.StreamsSucceeded != 0 || sktData.StreamsFailed != 1 { + return false, fmt.Errorf("Server socket metric with ID %d, want (StreamsStarted, StreamsSucceeded, StreamsFailed) = (1, 0, 1), got (%d, %d, %d)", ns[0].ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed) + } + return true, nil + }); err != nil { + t.Fatal(err) + } +} + +func TestCZClientAndServerSocketMetricsFlowControl(t *testing.T) { + defer leakcheck.Check(t) + channelz.NewChannelzStorage() + e := tcpClearRREnv + te := newTest(t, e) + // disable BDP + te.serverInitialWindowSize = 65536 + te.serverInitialConnWindowSize = 65536 + te.clientInitialWindowSize = 65536 + te.clientInitialConnWindowSize = 65536 + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + + for i := 0; i < 10; i++ { + doSuccessfulUnaryCall(tc, t) + } + + var cliSktID, svrSktID int64 + if err := verifyResultWithDelay(func() (bool, error) { + tchan, _ := channelz.GetTopChannels(0) + if len(tchan) != 1 { + return false, fmt.Errorf("There should only be one top channel, not %d", len(tchan)) + } + if len(tchan[0].SubChans) != 1 { + 
return false, fmt.Errorf("There should only be one subchannel under top channel %d, not %d", tchan[0].ID, len(tchan[0].SubChans)) + } + var id int64 + for id = range tchan[0].SubChans { + break + } + sc := channelz.GetSubChannel(id) + if sc == nil { + return false, fmt.Errorf("There should only be one socket under subchannel %d, not 0", id) + } + if len(sc.Sockets) != 1 { + return false, fmt.Errorf("There should only be one socket under subchannel %d, not %d", sc.ID, len(sc.Sockets)) + } + for id = range sc.Sockets { + break + } + skt := channelz.GetSocket(id) + sktData := skt.SocketData + // 65536 - 5 (Length-Prefixed-Message size) * 10 = 65486 + if sktData.LocalFlowControlWindow != 65486 || sktData.RemoteFlowControlWindow != 65486 { + return false, fmt.Errorf("Client: (LocalFlowControlWindow, RemoteFlowControlWindow) size should be (65536, 65486), not (%d, %d)", sktData.LocalFlowControlWindow, sktData.RemoteFlowControlWindow) + } + ss, _ := channelz.GetServers(0) + if len(ss) != 1 { + return false, fmt.Errorf("There should only be one server, not %d", len(ss)) + } + ns, _ := channelz.GetServerSockets(ss[0].ID, 0) + sktData = ns[0].SocketData + if sktData.LocalFlowControlWindow != 65486 || sktData.RemoteFlowControlWindow != 65486 { + return false, fmt.Errorf("Server: (LocalFlowControlWindow, RemoteFlowControlWindow) size should be (65536, 65486), not (%d, %d)", sktData.LocalFlowControlWindow, sktData.RemoteFlowControlWindow) + } + cliSktID, svrSktID = id, ss[0].ID + return true, nil + }); err != nil { + t.Fatal(err) + } + + doStreamingInputCallWithLargePayload(tc, t) + + if err := verifyResultWithDelay(func() (bool, error) { + skt := channelz.GetSocket(cliSktID) + sktData := skt.SocketData + // Local: 65536 - 5 (Length-Prefixed-Message size) * 10 = 65486 + // Remote: 65536 - 5 (Length-Prefixed-Message size) * 10 - 10011 = 55475 + if sktData.LocalFlowControlWindow != 65486 || sktData.RemoteFlowControlWindow != 55475 { + return false, fmt.Errorf("Client: (LocalFlowControlWindow, RemoteFlowControlWindow) size should be (65486, 55475), not (%d, %d)", sktData.LocalFlowControlWindow, sktData.RemoteFlowControlWindow) + } + ss, _ := channelz.GetServers(0) + if len(ss) != 1 { + return false, fmt.Errorf("There should only be one server, not %d", len(ss)) + } + ns, _ := channelz.GetServerSockets(svrSktID, 0) + sktData = ns[0].SocketData + if sktData.LocalFlowControlWindow != 55475 || sktData.RemoteFlowControlWindow != 65486 { + return false, fmt.Errorf("Server: (LocalFlowControlWindow, RemoteFlowControlWindow) size should be (55475, 65486), not (%d, %d)", sktData.LocalFlowControlWindow, sktData.RemoteFlowControlWindow) + } + return true, nil + }); err != nil { + t.Fatal(err) + } + + // triggers transport flow control window update on server side, since unacked + // bytes should be larger than limit now. i.e. 50 + 20022 > 65536/4. 
+ doStreamingInputCallWithLargePayload(tc, t) + if err := verifyResultWithDelay(func() (bool, error) { + skt := channelz.GetSocket(cliSktID) + sktData := skt.SocketData + // Local: 65536 - 5 (Length-Prefixed-Message size) * 10 = 65486 + // Remote: 65536 + if sktData.LocalFlowControlWindow != 65486 || sktData.RemoteFlowControlWindow != 65536 { + return false, fmt.Errorf("Client: (LocalFlowControlWindow, RemoteFlowControlWindow) size should be (65486, 65536), not (%d, %d)", sktData.LocalFlowControlWindow, sktData.RemoteFlowControlWindow) + } + ss, _ := channelz.GetServers(0) + if len(ss) != 1 { + return false, fmt.Errorf("There should only be one server, not %d", len(ss)) + } + ns, _ := channelz.GetServerSockets(svrSktID, 0) + sktData = ns[0].SocketData + if sktData.LocalFlowControlWindow != 65536 || sktData.RemoteFlowControlWindow != 65486 { + return false, fmt.Errorf("Server: (LocalFlowControlWindow, RemoteFlowControlWindow) size should be (65536, 65486), not (%d, %d)", sktData.LocalFlowControlWindow, sktData.RemoteFlowControlWindow) + } + return true, nil + }); err != nil { + t.Fatal(err) + } +} + +func TestCZClientSocketMetricsKeepAlive(t *testing.T) { + defer leakcheck.Check(t) + channelz.NewChannelzStorage() + e := tcpClearRREnv + te := newTest(t, e) + te.cliKeepAlive = &keepalive.ClientParameters{Time: 500 * time.Millisecond, Timeout: 500 * time.Millisecond} + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + doIdleCallToInvokeKeepAlive(tc, t) + + if err := verifyResultWithDelay(func() (bool, error) { + tchan, _ := channelz.GetTopChannels(0) + if len(tchan) != 1 { + return false, fmt.Errorf("There should only be one top channel, not %d", len(tchan)) + } + if len(tchan[0].SubChans) != 1 { + return false, fmt.Errorf("There should only be one subchannel under top channel %d, not %d", tchan[0].ID, len(tchan[0].SubChans)) + } + var id int64 + for id = range tchan[0].SubChans { + break + } + sc := channelz.GetSubChannel(id) + if sc == nil { + return false, fmt.Errorf("There should only be one socket under subchannel %d, not 0", id) + } + if len(sc.Sockets) != 1 { + return false, fmt.Errorf("There should only be one socket under subchannel %d, not %d", sc.ID, len(sc.Sockets)) + } + for id = range sc.Sockets { + break + } + skt := channelz.GetSocket(id) + if skt.SocketData.KeepAlivesSent != 2 { // doIdleCallToInvokeKeepAlive func is set up to send 2 KeepAlives. 
+ return false, fmt.Errorf("There should be 2 KeepAlives sent, not %d", skt.SocketData.KeepAlivesSent) + } + return true, nil + }); err != nil { + t.Fatal(err) + } +} + +func TestCZServerSocketMetricsStreamsAndMessagesCount(t *testing.T) { + defer leakcheck.Check(t) + channelz.NewChannelzStorage() + e := tcpClearRREnv + te := newTest(t, e) + te.maxServerReceiveMsgSize = newInt(20) + te.maxClientReceiveMsgSize = newInt(20) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + cc, _ := te.clientConnWithConnControl() + tc := &testServiceClientWrapper{TestServiceClient: testpb.NewTestServiceClient(cc)} + + var svrID int64 + if err := verifyResultWithDelay(func() (bool, error) { + ss, _ := channelz.GetServers(0) + if len(ss) != 1 { + return false, fmt.Errorf("There should only be one server, not %d", len(ss)) + } + svrID = ss[0].ID + return true, nil + }); err != nil { + t.Fatal(err) + } + + doSuccessfulUnaryCall(tc, t) + if err := verifyResultWithDelay(func() (bool, error) { + ns, _ := channelz.GetServerSockets(svrID, 0) + sktData := ns[0].SocketData + if sktData.StreamsStarted != 1 || sktData.StreamsSucceeded != 1 || sktData.StreamsFailed != 0 || sktData.MessagesSent != 1 || sktData.MessagesReceived != 1 { + return false, fmt.Errorf("Server socket metric with ID %d, want (StreamsStarted, StreamsSucceeded, MessagesSent, MessagesReceived) = (1, 1, 1, 1), got (%d, %d, %d, %d, %d)", ns[0].ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } + + doServerSideFailedUnaryCall(tc, t) + if err := verifyResultWithDelay(func() (bool, error) { + ns, _ := channelz.GetServerSockets(svrID, 0) + sktData := ns[0].SocketData + if sktData.StreamsStarted != 2 || sktData.StreamsSucceeded != 2 || sktData.StreamsFailed != 0 || sktData.MessagesSent != 1 || sktData.MessagesReceived != 1 { + return false, fmt.Errorf("Server socket metric with ID %d, want (StreamsStarted, StreamsSucceeded, StreamsFailed, MessagesSent, MessagesReceived) = (2, 2, 0, 1, 1), got (%d, %d, %d, %d, %d)", ns[0].ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } + + doClientSideFailedUnaryCall(tc, t) + if err := verifyResultWithDelay(func() (bool, error) { + ns, _ := channelz.GetServerSockets(svrID, 0) + sktData := ns[0].SocketData + if sktData.StreamsStarted != 3 || sktData.StreamsSucceeded != 3 || sktData.StreamsFailed != 0 || sktData.MessagesSent != 2 || sktData.MessagesReceived != 2 { + return false, fmt.Errorf("Server socket metric with ID %d, want (StreamsStarted, StreamsSucceeded, StreamsFailed, MessagesSent, MessagesReceived) = (3, 3, 0, 2, 2), got (%d, %d, %d, %d, %d)", ns[0].ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } + + doClientSideInitiatedFailedStream(tc, t) + if err := verifyResultWithDelay(func() (bool, error) { + ns, _ := channelz.GetServerSockets(svrID, 0) + sktData := ns[0].SocketData + if sktData.StreamsStarted != 4 || sktData.StreamsSucceeded != 3 || sktData.StreamsFailed != 1 || sktData.MessagesSent != 3 || sktData.MessagesReceived != 3 { + return false, fmt.Errorf("Server socket metric with ID %d, want (StreamsStarted, StreamsSucceeded, StreamsFailed, MessagesSent, MessagesReceived) = (4, 3, 1, 3, 
3), got (%d, %d, %d, %d, %d)", ns[0].ID, sktData.StreamsStarted, sktData.StreamsSucceeded, sktData.StreamsFailed, sktData.MessagesSent, sktData.MessagesReceived) + } + return true, nil + }); err != nil { + t.Fatal(err) + } +} + +func TestCZServerSocketMetricsKeepAlive(t *testing.T) { + defer leakcheck.Check(t) + channelz.NewChannelzStorage() + e := tcpClearRREnv + te := newTest(t, e) + te.svrKeepAlive = &keepalive.ServerParameters{Time: 500 * time.Millisecond, Timeout: 500 * time.Millisecond} + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + doIdleCallToInvokeKeepAlive(tc, t) + + if err := verifyResultWithDelay(func() (bool, error) { + ss, _ := channelz.GetServers(0) + if len(ss) != 1 { + return false, fmt.Errorf("There should be one server, not %d", len(ss)) + } + ns, _ := channelz.GetServerSockets(ss[0].ID, 0) + if len(ns) != 1 { + return false, fmt.Errorf("There should be one server normal socket, not %d", len(ns)) + } + if ns[0].SocketData.KeepAlivesSent != 2 { // doIdleCallToInvokeKeepAlive func is set up to send 2 KeepAlives. + return false, fmt.Errorf("There should be 2 KeepAlives sent, not %d", ns[0].SocketData.KeepAlivesSent) + } + return true, nil + }); err != nil { + t.Fatal(err) + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index e0e7df606..ace57cb72 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -55,6 +55,7 @@ import ( "google.golang.org/grpc/health" healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/internal" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -473,6 +474,8 @@ type test struct { perRPCCreds credentials.PerRPCCredentials customDialOptions []grpc.DialOption resolverScheme string + cliKeepAlive *keepalive.ClientParameters + svrKeepAlive *keepalive.ServerParameters // All test dialing is blocking by default. Set this to true if dial // should be non-blocking. @@ -495,14 +498,17 @@ func (te *test) tearDown() { te.cancel() te.cancel = nil } + if te.cc != nil { te.cc.Close() te.cc = nil } + if te.restoreLogs != nil { te.restoreLogs() te.restoreLogs = nil } + if te.srv != nil { te.srv.Stop() } @@ -526,9 +532,7 @@ func newTest(t *testing.T, e env) *test { return te } -// startServer starts a gRPC server listening. Callers should defer a -// call to te.tearDown to clean up. -func (te *test) startServer(ts testpb.TestServiceServer) { +func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, address string) (net.Listener, error)) net.Listener { te.testServer = ts te.t.Logf("Running test in %s environment...", te.e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} @@ -571,7 +575,7 @@ func (te *test) startServer(ts testpb.TestServiceServer) { la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now().UnixNano()) syscall.Unlink(la) } - lis, err := net.Listen(te.e.network, la) + lis, err := listen(te.e.network, la) if err != nil { te.t.Fatalf("Failed to listen: %v", err) } @@ -588,6 +592,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) { if te.customCodec != nil { sopts = append(sopts, grpc.CustomCodec(te.customCodec)) } + if te.svrKeepAlive != nil { + sopts = append(sopts, grpc.KeepaliveParams(*te.svrKeepAlive)) + } s := grpc.NewServer(sopts...) 
te.srv = s if te.e.httpHandler { @@ -612,6 +619,18 @@ func (te *test) startServer(ts testpb.TestServiceServer) { go s.Serve(lis) te.srvAddr = addr + return lis +} + +func (te *test) startServerWithConnControl(ts testpb.TestServiceServer) *listenerWrapper { + l := te.listenAndServe(ts, listenWithConnControl) + return l.(*listenerWrapper) +} + +// startServer starts a gRPC server listening. Callers should defer a +// call to te.tearDown to clean up. +func (te *test) startServer(ts testpb.TestServiceServer) { + te.listenAndServe(ts, net.Listen) } type nopCompressor struct { @@ -640,10 +659,7 @@ func (d *nopDecompressor) Type() string { return "nop" } -func (te *test) clientConn(opts ...grpc.DialOption) *grpc.ClientConn { - if te.cc != nil { - return te.cc - } +func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) { opts = append(opts, grpc.WithDialer(te.e.dialer), grpc.WithUserAgent(te.userAgent)) if te.sc != nil { @@ -724,7 +740,35 @@ func (te *test) clientConn(opts ...grpc.DialOption) *grpc.ClientConn { if te.srvAddr == "" { te.srvAddr = "client.side.only.test" } + if te.cliKeepAlive != nil { + opts = append(opts, grpc.WithKeepaliveParams(*te.cliKeepAlive)) + } opts = append(opts, te.customDialOptions...) + return opts, scheme +} + +func (te *test) clientConnWithConnControl() (*grpc.ClientConn, *dialerWrapper) { + if te.cc != nil { + return te.cc, nil + } + opts, scheme := te.configDial() + dw := &dialerWrapper{} + // overwrite the dialer before + opts = append(opts, grpc.WithDialer(dw.dialer)) + var err error + te.cc, err = grpc.Dial(scheme+te.srvAddr, opts...) + if err != nil { + te.t.Fatalf("Dial(%q) = %v", scheme+te.srvAddr, err) + } + return te.cc, dw +} + +func (te *test) clientConn(opts ...grpc.DialOption) *grpc.ClientConn { + if te.cc != nil { + return te.cc + } + var scheme string + opts, scheme = te.configDial(opts...) var err error te.cc, err = grpc.Dial(scheme+te.srvAddr, opts...) if err != nil { diff --git a/test/rawConnWrapper.go b/test/rawConnWrapper.go new file mode 100644 index 000000000..8d4ee0203 --- /dev/null +++ b/test/rawConnWrapper.go @@ -0,0 +1,345 @@ +/* + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "bytes" + "fmt" + "io" + "net" + "strings" + "sync" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +type listenerWrapper struct { + net.Listener + mu sync.Mutex + rcw *rawConnWrapper +} + +func listenWithConnControl(network, address string) (net.Listener, error) { + l, err := net.Listen(network, address) + if err != nil { + return nil, err + } + return &listenerWrapper{Listener: l}, nil +} + +// Accept blocks until Dial is called, then returns a net.Conn for the server +// half of the connection. 
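+// It also wraps the accepted conn in a rawConnWrapper and records it, so a test can
+// retrieve it via getLastConn and write raw HTTP/2 frames (e.g. RST_STREAM or GOAWAY)
+// from the server side of the connection.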
+func (l *listenerWrapper) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + l.mu.Lock() + l.rcw = newRawConnWrapperFromConn(c) + l.mu.Unlock() + return c, nil +} + +func (l *listenerWrapper) getLastConn() *rawConnWrapper { + l.mu.Lock() + defer l.mu.Unlock() + return l.rcw +} + +type dialerWrapper struct { + c net.Conn + rcw *rawConnWrapper +} + +func (d *dialerWrapper) dialer(target string, t time.Duration) (net.Conn, error) { + c, err := net.DialTimeout("tcp", target, t) + d.c = c + d.rcw = newRawConnWrapperFromConn(c) + return c, err +} + +func (d *dialerWrapper) getRawConnWrapper() *rawConnWrapper { + return d.rcw +} + +type rawConnWrapper struct { + cc io.ReadWriteCloser + fr *http2.Framer + + // writing headers: + headerBuf bytes.Buffer + hpackEnc *hpack.Encoder + + // reading frames: + frc chan http2.Frame + frErrc chan error + readTimer *time.Timer +} + +func newRawConnWrapperFromConn(cc io.ReadWriteCloser) *rawConnWrapper { + rcw := &rawConnWrapper{ + cc: cc, + frc: make(chan http2.Frame, 1), + frErrc: make(chan error, 1), + } + rcw.hpackEnc = hpack.NewEncoder(&rcw.headerBuf) + rcw.fr = http2.NewFramer(cc, cc) + rcw.fr.ReadMetaHeaders = hpack.NewDecoder(4096 /*initialHeaderTableSize*/, nil) + + return rcw +} + +func (rcw *rawConnWrapper) Close() error { + return rcw.cc.Close() +} + +func (rcw *rawConnWrapper) readFrame() (http2.Frame, error) { + go func() { + fr, err := rcw.fr.ReadFrame() + if err != nil { + rcw.frErrc <- err + } else { + rcw.frc <- fr + } + }() + t := time.NewTimer(2 * time.Second) + defer t.Stop() + select { + case f := <-rcw.frc: + return f, nil + case err := <-rcw.frErrc: + return nil, err + case <-t.C: + return nil, fmt.Errorf("timeout waiting for frame") + } +} + +// greet initiates the client's HTTP/2 connection into a state where +// frames may be sent. +func (rcw *rawConnWrapper) greet() error { + rcw.writePreface() + rcw.writeInitialSettings() + rcw.wantSettings() + rcw.writeSettingsAck() + for { + f, err := rcw.readFrame() + if err != nil { + return err + } + switch f := f.(type) { + case *http2.WindowUpdateFrame: + // grpc's transport/http2_server sends this + // before the settings ack. The Go http2 + // server uses a setting instead. 
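Together, listenWithConnControl and dialerWrapper give a test direct access to the raw TCP connection on each side of a gRPC channel, so HTTP/2 frames can be read or written underneath the client and server. A hedged usage sketch; the helpers are the ones introduced in this change, but the surrounding test body is illustrative only:

func exampleConnControl(te *test) {
	lw := te.startServerWithConnControl(&testServer{})
	defer te.tearDown()
	_, dw := te.clientConnWithConnControl()
	clientSide := dw.getRawConnWrapper() // frames as seen from the client end
	serverSide := lw.getLastConn()       // nil until a client connection has been accepted
	_, _ = clientSide, serverSide
}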
+ case *http2.SettingsFrame: + if f.IsAck() { + return nil + } + return fmt.Errorf("during greet, got non-ACK settings frame") + default: + return fmt.Errorf("during greet, unexpected frame type %T", f) + } + } +} + +func (rcw *rawConnWrapper) writePreface() error { + n, err := rcw.cc.Write([]byte(http2.ClientPreface)) + if err != nil { + return fmt.Errorf("Error writing client preface: %v", err) + } + if n != len(http2.ClientPreface) { + return fmt.Errorf("Writing client preface, wrote %d bytes; want %d", n, len(http2.ClientPreface)) + } + return nil +} + +func (rcw *rawConnWrapper) writeInitialSettings() error { + if err := rcw.fr.WriteSettings(); err != nil { + return fmt.Errorf("Error writing initial SETTINGS frame from client to server: %v", err) + } + return nil +} + +func (rcw *rawConnWrapper) writeSettingsAck() error { + if err := rcw.fr.WriteSettingsAck(); err != nil { + return fmt.Errorf("Error writing ACK of server's SETTINGS: %v", err) + } + return nil +} + +func (rcw *rawConnWrapper) wantSettings() (*http2.SettingsFrame, error) { + f, err := rcw.readFrame() + if err != nil { + return nil, fmt.Errorf("Error while expecting a SETTINGS frame: %v", err) + } + sf, ok := f.(*http2.SettingsFrame) + if !ok { + return nil, fmt.Errorf("got a %T; want *SettingsFrame", f) + } + return sf, nil +} + +func (rcw *rawConnWrapper) wantSettingsAck() error { + f, err := rcw.readFrame() + if err != nil { + return err + } + sf, ok := f.(*http2.SettingsFrame) + if !ok { + return fmt.Errorf("Wanting a settings ACK, received a %T", f) + } + if !sf.IsAck() { + return fmt.Errorf("Settings Frame didn't have ACK set") + } + return nil +} + +// wait for any activity from the server +func (rcw *rawConnWrapper) wantAnyFrame() (http2.Frame, error) { + f, err := rcw.fr.ReadFrame() + if err != nil { + return nil, err + } + return f, nil +} + +func (rcw *rawConnWrapper) encodeHeaderField(k, v string) error { + err := rcw.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) + if err != nil { + return fmt.Errorf("HPACK encoding error for %q/%q: %v", k, v, err) + } + return nil +} + +// encodeHeader encodes headers and returns their HPACK bytes. headers +// must contain an even number of key/value pairs. There may be +// multiple pairs for keys (e.g. "cookie"). The :method, :path, and +// :scheme headers default to GET, / and https. +func (rcw *rawConnWrapper) encodeHeader(headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + + rcw.headerBuf.Reset() + + if len(headers) == 0 { + // Fast path, mostly for benchmarks, so test code doesn't pollute + // profiles when we're looking to improve server allocations. + rcw.encodeHeaderField(":method", "GET") + rcw.encodeHeaderField(":path", "/") + rcw.encodeHeaderField(":scheme", "https") + return rcw.headerBuf.Bytes() + } + + if len(headers) == 2 && headers[0] == ":method" { + // Another fast path for benchmarks. 
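greet walks the raw connection through the standard HTTP/2 opening handshake (client preface, SETTINGS exchange, SETTINGS ack), after which request frames are legal. A hedged sketch of driving it from a test via the dialer wrapper above; the gRPC-shaped HEADERS call mentioned in the comment is writeHeadersGRPC, defined further down in this file:

func exampleRawHandshake(dw *dialerWrapper) error {
	rcw := dw.getRawConnWrapper()
	if err := rcw.greet(); err != nil {
		return err
	}
	// The connection is now ready for request frames, for example:
	// rcw.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall")
	return nil
}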
+ rcw.encodeHeaderField(":method", headers[1]) + rcw.encodeHeaderField(":path", "/") + rcw.encodeHeaderField(":scheme", "https") + return rcw.headerBuf.Bytes() + } + + pseudoCount := map[string]int{} + keys := []string{":method", ":path", ":scheme"} + vals := map[string][]string{ + ":method": {"GET"}, + ":path": {"/"}, + ":scheme": {"https"}, + } + for len(headers) > 0 { + k, v := headers[0], headers[1] + headers = headers[2:] + if _, ok := vals[k]; !ok { + keys = append(keys, k) + } + if strings.HasPrefix(k, ":") { + pseudoCount[k]++ + if pseudoCount[k] == 1 { + vals[k] = []string{v} + } else { + // Allows testing of invalid headers w/ dup pseudo fields. + vals[k] = append(vals[k], v) + } + } else { + vals[k] = append(vals[k], v) + } + } + for _, k := range keys { + for _, v := range vals[k] { + rcw.encodeHeaderField(k, v) + } + } + return rcw.headerBuf.Bytes() +} + +func (rcw *rawConnWrapper) writeHeadersGRPC(streamID uint32, path string) { + rcw.writeHeaders(http2.HeadersFrameParam{ + StreamID: streamID, + BlockFragment: rcw.encodeHeader( + ":method", "POST", + ":path", path, + "content-type", "application/grpc", + "te", "trailers", + ), + EndStream: false, + EndHeaders: true, + }) +} + +func (rcw *rawConnWrapper) writeHeaders(p http2.HeadersFrameParam) error { + if err := rcw.fr.WriteHeaders(p); err != nil { + return fmt.Errorf("Error writing HEADERS: %v", err) + } + return nil +} + +func (rcw *rawConnWrapper) writeData(streamID uint32, endStream bool, data []byte) error { + if err := rcw.fr.WriteData(streamID, endStream, data); err != nil { + return fmt.Errorf("Error writing DATA: %v", err) + } + return nil +} + +func (rcw *rawConnWrapper) writeRSTStream(streamID uint32, code http2.ErrCode) error { + if err := rcw.fr.WriteRSTStream(streamID, code); err != nil { + return fmt.Errorf("Error writing RST_STREAM: %v", err) + } + return nil +} + +func (rcw *rawConnWrapper) writeDataPadded(streamID uint32, endStream bool, data, padding []byte) error { + if err := rcw.fr.WriteDataPadded(streamID, endStream, data, padding); err != nil { + return fmt.Errorf("Error writing DATA with padding: %v", err) + } + return nil +} + +func (rcw *rawConnWrapper) writeGoAway(maxStreamID uint32, code http2.ErrCode, debugData []byte) error { + if err := rcw.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { + return fmt.Errorf("Error writing GoAway: %v", err) + } + return nil +} + +func (rcw *rawConnWrapper) writeRawFrame(t http2.FrameType, flags http2.Flags, streamID uint32, payload []byte) error { + if err := rcw.fr.WriteRawFrame(t, flags, streamID, payload); err != nil { + return fmt.Errorf("Error writing Raw Frame: %v", err) + } + return nil +} diff --git a/transport/controlbuf.go b/transport/controlbuf.go index 6bcb9b7bf..72e025cc0 100644 --- a/transport/controlbuf.go +++ b/transport/controlbuf.go @@ -145,6 +145,10 @@ type ping struct { data [8]byte } +type outFlowControlSizeRequest struct { + resp chan uint32 +} + type outStreamState int const ( @@ -569,6 +573,11 @@ func (l *loopyWriter) pingHandler(p *ping) error { } +func (l *loopyWriter) outFlowControlSizeRequestHanlder(o *outFlowControlSizeRequest) error { + o.resp <- l.sendQuota + return nil +} + func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { c.onWrite() if str, ok := l.estdStreams[c.streamID]; ok { @@ -633,6 +642,8 @@ func (l *loopyWriter) handle(i interface{}) error { return l.pingHandler(i) case *goAway: return l.goAwayHandler(i) + case *outFlowControlSizeRequest: + return l.outFlowControlSizeRequestHanlder(i) 
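outFlowControlSizeRequest turns the control buffer into a small request/response mechanism: the caller enqueues a message carrying a reply channel, and the loopy writer, which owns sendQuota, answers from its own goroutine, so no extra locking is needed. A minimal sketch of the calling side, consolidating the getOutFlowWindow helpers that appear later in this diff (illustrative only; cbuf stands for the transport's existing controlBuffer from controlbuf.go):

func queryOutFlowWindow(cbuf *controlBuffer, done <-chan struct{}) int64 {
	resp := make(chan uint32, 1) // buffered, so loopy's reply never blocks if we time out first
	cbuf.put(&outFlowControlSizeRequest{resp})
	timer := time.NewTimer(time.Second)
	defer timer.Stop()
	select {
	case sz := <-resp:
		return int64(sz)
	case <-done:
		return -1 // transport is shutting down
	case <-timer.C:
		return -2 // loopy did not reply in time
	}
}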
default: return fmt.Errorf("transport: unknown control message type %T", i) } diff --git a/transport/flowcontrol.go b/transport/flowcontrol.go index 5474e89af..378f5c450 100644 --- a/transport/flowcontrol.go +++ b/transport/flowcontrol.go @@ -96,13 +96,15 @@ func (w *writeQuota) replenish(n int) { } type trInFlow struct { - limit uint32 - unacked uint32 + limit uint32 + unacked uint32 + effectiveWindowSize uint32 } func (f *trInFlow) newLimit(n uint32) uint32 { d := n - f.limit f.limit = n + f.updateEffectiveWindowSize() return d } @@ -111,17 +113,28 @@ func (f *trInFlow) onData(n uint32) uint32 { if f.unacked >= f.limit/4 { w := f.unacked f.unacked = 0 + f.updateEffectiveWindowSize() return w } + f.updateEffectiveWindowSize() return 0 } func (f *trInFlow) reset() uint32 { w := f.unacked f.unacked = 0 + f.updateEffectiveWindowSize() return w } +func (f *trInFlow) updateEffectiveWindowSize() { + atomic.StoreUint32(&f.effectiveWindowSize, f.limit-f.unacked) +} + +func (f *trInFlow) getSize() uint32 { + return atomic.LoadUint32(&f.effectiveWindowSize) +} + // TODO(mmukhi): Simplify this code. // inFlow deals with inbound flow control type inFlow struct { diff --git a/transport/http2_client.go b/transport/http2_client.go index fe904e788..ab97cb571 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -104,7 +104,21 @@ type http2Client struct { // GoAway frame. goAwayReason GoAwayReason + // Fields below are for channelz metric collection. channelzID int64 // channelz unique identification number + czmu sync.RWMutex + kpCount int64 + // The number of streams that have started, including already finished ones. + streamsStarted int64 + // The number of streams that have ended successfully by receiving EoS bit set + // frame from server. + streamsSucceeded int64 + streamsFailed int64 + lastStreamCreated time.Time + msgSent int64 + msgRecv int64 + lastMsgSent time.Time + lastMsgRecv time.Time } func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { @@ -514,6 +528,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return false, err } t.activeStreams[id] = s + if channelz.IsOn() { + t.czmu.Lock() + t.streamsStarted++ + t.lastStreamCreated = time.Now() + t.czmu.Unlock() + } var sendPing bool // If the number of active streams change from 0 to 1, then check if keepalive // has gone dormant. If so, wake it up. @@ -604,10 +624,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) { rst = true rstCode = http2.ErrCodeCancel } - t.closeStream(s, err, rst, rstCode, nil, nil) + t.closeStream(s, err, rst, rstCode, nil, nil, false) } -func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string) { +func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) { // Set stream status to done. if s.swapState(streamDone) == streamDone { // If it was already done, return. @@ -638,6 +658,15 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. delete(t.activeStreams, s.id) } t.mu.Unlock() + if channelz.IsOn() { + t.czmu.Lock() + if eosReceived { + t.streamsSucceeded++ + } else { + t.streamsFailed++ + } + t.czmu.Unlock() + } }, rst: rst, rstCode: rstCode, @@ -677,7 +706,7 @@ func (t *http2Client) Close() error { } // Notify all active streams. 
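effectiveWindowSize is simply limit minus unacked, refreshed on every state change so that ChannelzMetric (later in this diff) can report LocalFlowControlWindow with a single atomic load instead of taking transport locks. A short illustrative sequence (the numbers are examples only):

func exampleTrInFlowAccounting() {
	f := &trInFlow{limit: 65535}
	f.onData(1000)       // 1,000 bytes received, not yet acknowledged
	_ = f.getSize()      // 64535 = limit - unacked
	w := f.onData(20000) // unacked (21,000) crosses limit/4, so it is returned ...
	_ = w                // ... to be sent as a WINDOW_UPDATE, and unacked resets to 0
	_ = f.getSize()      // 65535 again
}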
for _, s := range streams { - t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, nil, nil) + t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, nil, nil, false) } if t.statsHandler != nil { connEnd := &stats.ConnEnd{ @@ -832,7 +861,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { - t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil) + t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) return } if f.Header().Flags.Has(http2.FlagDataPadded) { @@ -852,7 +881,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { // The server has closed the stream without sending trailers. Record that // the read direction is closed, and set the status appropriately. if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) } } @@ -870,7 +899,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { warningf("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error %v", f.ErrCode) statusCode = codes.Unknown } - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode), nil, false) } func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) { @@ -974,7 +1003,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. atomic.StoreUint32(&stream.unprocessed, 1) - t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil) + t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } } t.prevGoAwayID = id @@ -1022,7 +1051,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { var state decodeState if err := state.decodeResponseHeader(frame); err != nil { // TODO(mmukhi, dfawley): Perhaps send a reset stream. - t.closeStream(s, err, false, http2.ErrCodeNo, nil, nil) + t.closeStream(s, err, false, http2.ErrCodeNo, nil, nil, false) // Something wrong. Stops reading even when there is remaining. return } @@ -1064,7 +1093,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { if !endStream { return } - t.closeStream(s, io.EOF, false, http2.ErrCodeNo, state.status(), state.mdata) + t.closeStream(s, io.EOF, false, http2.ErrCodeNo, state.status(), state.mdata, true) } // reader runs as a separate goroutine in charge of reading data from network @@ -1105,7 +1134,7 @@ func (t *http2Client) reader() { if s != nil { // use error detail to provide better err message // TODO(mmukhi, dfawley): Perhaps send a RST_STREAM to the server. 
- t.closeStream(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail()), false, http2.ErrCodeNo, nil, nil) + t.closeStream(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail()), false, http2.ErrCodeNo, nil, nil, false) } continue } else { @@ -1161,6 +1190,11 @@ func (t *http2Client) keepalive() { } } else { t.mu.Unlock() + if channelz.IsOn() { + t.czmu.Lock() + t.kpCount++ + t.czmu.Unlock() + } // Send ping. t.controlBuf.put(p) } @@ -1199,8 +1233,54 @@ func (t *http2Client) GoAway() <-chan struct{} { } func (t *http2Client) ChannelzMetric() *channelz.SocketInternalMetric { - return &channelz.SocketInternalMetric{} + t.czmu.RLock() + s := channelz.SocketInternalMetric{ + StreamsStarted: t.streamsStarted, + StreamsSucceeded: t.streamsSucceeded, + StreamsFailed: t.streamsFailed, + MessagesSent: t.msgSent, + MessagesReceived: t.msgRecv, + KeepAlivesSent: t.kpCount, + LastLocalStreamCreatedTimestamp: t.lastStreamCreated, + LastMessageSentTimestamp: t.lastMsgSent, + LastMessageReceivedTimestamp: t.lastMsgRecv, + LocalFlowControlWindow: int64(t.fc.getSize()), + //socket options + LocalAddr: t.localAddr, + RemoteAddr: t.remoteAddr, + // Security + // RemoteName : + } + t.czmu.RUnlock() + s.RemoteFlowControlWindow = t.getOutFlowWindow() + return &s } -func (t *http2Client) IncrMsgSent() {} -func (t *http2Client) IncrMsgRecv() {} +func (t *http2Client) IncrMsgSent() { + t.czmu.Lock() + t.msgSent++ + t.lastMsgSent = time.Now() + t.czmu.Unlock() +} + +func (t *http2Client) IncrMsgRecv() { + t.czmu.Lock() + t.msgRecv++ + t.lastMsgRecv = time.Now() + t.czmu.Unlock() +} + +func (t *http2Client) getOutFlowWindow() int64 { + resp := make(chan uint32, 1) + timer := time.NewTimer(time.Second) + defer timer.Stop() + t.controlBuf.put(&outFlowControlSizeRequest{resp}) + select { + case sz := <-resp: + return int64(sz) + case <-t.ctxDone: + return -1 + case <-timer.C: + return -2 + } +} diff --git a/transport/http2_server.go b/transport/http2_server.go index aa996b46b..8b93e222e 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -92,8 +92,6 @@ type http2Server struct { initialWindowSize int32 bdpEst *bdpEstimator - channelzID int64 // channelz unique identification number - mu sync.Mutex // guard the following // drainChan is initialized when drain(...) is called the first time. @@ -110,6 +108,22 @@ type http2Server struct { // RPCs go down to 0. // When the connection is busy, this value is set to 0. idle time.Time + + // Fields below are for channelz metric collection. + channelzID int64 // channelz unique identification number + czmu sync.RWMutex + kpCount int64 + // The number of streams that have started, including already finished ones. + streamsStarted int64 + // The number of streams that have ended successfully by sending frame with + // EoS bit set. + streamsSucceeded int64 + streamsFailed int64 + lastStreamCreated time.Time + msgSent int64 + msgRecv int64 + lastMsgSent time.Time + lastMsgRecv time.Time } // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is @@ -295,7 +309,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( method: state.method, contentSubtype: state.contentSubtype, } - if frame.StreamEnded() { // s is just created by the caller. No lock needed. 
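For reference, the subset of channelz.SocketInternalMetric fields that the two ChannelzMetric implementations in this diff populate; this is reconstructed from the code above, not the channelz package's authoritative definition:

type socketMetricSubset struct {
	StreamsStarted, StreamsSucceeded, StreamsFailed int64
	MessagesSent, MessagesReceived, KeepAlivesSent  int64
	LastLocalStreamCreatedTimestamp  time.Time // client transport only
	LastRemoteStreamCreatedTimestamp time.Time // server transport only
	LastMessageSentTimestamp         time.Time
	LastMessageReceivedTimestamp     time.Time
	LocalFlowControlWindow  int64 // this side's inbound window, from trInFlow.getSize()
	RemoteFlowControlWindow int64 // the peer's send quota, from getOutFlowWindow()
	LocalAddr, RemoteAddr   net.Addr
}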
s.state = streamReadDone @@ -367,6 +380,12 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.idle = time.Time{} } t.mu.Unlock() + if channelz.IsOn() { + t.czmu.Lock() + t.streamsStarted++ + t.lastStreamCreated = time.Now() + t.czmu.Unlock() + } s.requestRead = func(n int) { t.adjustWindow(s, uint32(n)) } @@ -413,7 +432,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. s := t.activeStreams[se.StreamID] t.mu.Unlock() if s != nil { - t.closeStream(s, true, se.Code, nil) + t.closeStream(s, true, se.Code, nil, false) } else { t.controlBuf.put(&cleanupStream{ streamID: se.StreamID, @@ -555,7 +574,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { } if size > 0 { if err := s.fc.onData(size); err != nil { - t.closeStream(s, true, http2.ErrCodeFlowControl, nil) + t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false) return } if f.Header().Flags.Has(http2.FlagDataPadded) { @@ -584,7 +603,7 @@ func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { if !ok { return } - t.closeStream(s, false, 0, nil) + t.closeStream(s, false, 0, nil, false) } func (t *http2Server) handleSettings(f *http2.SettingsFrame) { @@ -764,7 +783,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { atomic.StoreUint32(&t.resetPingStrikes, 1) }, } - t.closeStream(s, false, 0, trailer) + t.closeStream(s, false, 0, trailer, true) if t.stats != nil { t.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) } @@ -890,6 +909,11 @@ func (t *http2Server) keepalive() { return } pingSent = true + if channelz.IsOn() { + t.czmu.Lock() + t.kpCount++ + t.czmu.Unlock() + } t.controlBuf.put(p) keepalive.Reset(t.kp.Timeout) case <-t.ctx.Done(): @@ -930,7 +954,7 @@ func (t *http2Server) Close() error { // closeStream clears the footprint of a stream when the stream is not needed // any more. -func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame) { +func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) { if s.swapState(streamDone) == streamDone { // If the stream was already done, return. 
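// Note on the eosReceived argument threaded through closeStream on both
// transports: it decides whether the stream is counted as succeeded or failed
// in channelz. On the server it is true only on the WriteStatus path, i.e.
// when trailers are written; RST_STREAM from the client, flow-control
// violations, and other error paths pass false. On the client it is true when
// trailers or an END_STREAM data frame arrive from the server, and false for
// RST_STREAM, GOAWAY, header-decode errors, and connection teardown.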
return @@ -952,6 +976,15 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, hd } } t.mu.Unlock() + if channelz.IsOn() { + t.czmu.Lock() + if eosReceived { + t.streamsSucceeded++ + } else { + t.streamsFailed++ + } + t.czmu.Unlock() + } }, } if hdr != nil { @@ -1038,11 +1071,57 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { } func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric { - return &channelz.SocketInternalMetric{} + t.czmu.RLock() + s := channelz.SocketInternalMetric{ + StreamsStarted: t.streamsStarted, + StreamsSucceeded: t.streamsSucceeded, + StreamsFailed: t.streamsFailed, + MessagesSent: t.msgSent, + MessagesReceived: t.msgRecv, + KeepAlivesSent: t.kpCount, + LastRemoteStreamCreatedTimestamp: t.lastStreamCreated, + LastMessageSentTimestamp: t.lastMsgSent, + LastMessageReceivedTimestamp: t.lastMsgRecv, + LocalFlowControlWindow: int64(t.fc.getSize()), + //socket options + LocalAddr: t.localAddr, + RemoteAddr: t.remoteAddr, + // Security + // RemoteName : + } + t.czmu.RUnlock() + s.RemoteFlowControlWindow = t.getOutFlowWindow() + return &s } -func (t *http2Server) IncrMsgSent() {} -func (t *http2Server) IncrMsgRecv() {} +func (t *http2Server) IncrMsgSent() { + t.czmu.Lock() + t.msgSent++ + t.lastMsgSent = time.Now() + t.czmu.Unlock() +} + +func (t *http2Server) IncrMsgRecv() { + t.czmu.Lock() + t.msgRecv++ + t.lastMsgRecv = time.Now() + t.czmu.Unlock() +} + +func (t *http2Server) getOutFlowWindow() int64 { + resp := make(chan uint32) + timer := time.NewTimer(time.Second) + defer timer.Stop() + t.controlBuf.put(&outFlowControlSizeRequest{resp}) + select { + case sz := <-resp: + return int64(sz) + case <-t.ctxDone: + return -1 + case <-timer.C: + return -2 + } +} var rgen = rand.New(rand.NewSource(time.Now().UnixNano())) diff --git a/transport/transport_test.go b/transport/transport_test.go index 46029b0a9..e0201efa8 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -478,7 +478,7 @@ func TestMaxConnectionIdle(t *testing.T) { if err != nil { t.Fatalf("Client failed to create RPC request: %v", err) } - client.(*http2Client).closeStream(stream, io.EOF, true, http2.ErrCodeCancel, nil, nil) + client.(*http2Client).closeStream(stream, io.EOF, true, http2.ErrCodeCancel, nil, nil, false) // wait for server to see that closed stream and max-age logic to send goaway after no new RPCs are mode timeout := time.NewTimer(time.Second * 4) select {
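ChannelzMetric, IncrMsgSent, and IncrMsgRecv are the three hooks channelz needs from a transport socket; with them, http2Client and http2Server can be registered as channelz sockets, and the counters are only maintained when channelz.IsOn() reports true. A hedged sketch of the interface shape these methods satisfy (the authoritative definition lives in the channelz package):

type channelzSocket interface {
	ChannelzMetric() *channelz.SocketInternalMetric
	IncrMsgSent()
	IncrMsgRecv()
}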