mirror of https://github.com/grpc/grpc-go.git
				
				
				
			Client load report for grpclb. (#1200)
This commit is contained in:
		
							parent
							
								
									a7fee9febf
								
							
						
					
					
						commit
						277e90a432
					
				
							
								
								
									
										56
									
								
								call.go
								
								
								
								
							
							
						
						
									
										56
									
								
								call.go
								
								
								
								
							|  | @ -93,11 +93,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran | |||
| } | ||||
| 
 | ||||
| // sendRequest writes out various information of an RPC such as Context and Message.
 | ||||
| func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { | ||||
| 	stream, err := t.NewStream(ctx, callHdr) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) { | ||||
| 	defer func() { | ||||
| 		if err != nil { | ||||
| 			// If err is connection error, t will be closed, no need to close stream here.
 | ||||
|  | @ -120,7 +116,7 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, | |||
| 	} | ||||
| 	outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload) | ||||
| 	if err != nil { | ||||
| 		return nil, Errorf(codes.Internal, "grpc: %v", err) | ||||
| 		return Errorf(codes.Internal, "grpc: %v", err) | ||||
| 	} | ||||
| 	err = t.Write(stream, outBuf, opts) | ||||
| 	if err == nil && outPayload != nil { | ||||
|  | @ -131,10 +127,10 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, | |||
| 	// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
 | ||||
| 	// recvResponse to get the final status.
 | ||||
| 	if err != nil && err != io.EOF { | ||||
| 		return nil, err | ||||
| 		return err | ||||
| 	} | ||||
| 	// Sent successfully.
 | ||||
| 	return stream, nil | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Invoke sends the RPC request on the wire and returns after response is received.
 | ||||
|  | @ -183,6 +179,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli | |||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| 	ctx = newContextWithRPCInfo(ctx) | ||||
| 	sh := cc.dopts.copts.StatsHandler | ||||
| 	if sh != nil { | ||||
| 		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) | ||||
|  | @ -246,19 +243,35 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli | |||
| 		if c.traceInfo.tr != nil { | ||||
| 			c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) | ||||
| 		} | ||||
| 		stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts) | ||||
| 		stream, err = t.NewStream(ctx, callHdr) | ||||
| 		if err != nil { | ||||
| 			if put != nil { | ||||
| 				if _, ok := err.(transport.ConnectionError); ok { | ||||
| 					// If error is connection error, transport was sending data on wire,
 | ||||
| 					// and we are not sure if anything has been sent on wire.
 | ||||
| 					// If error is not connection error, we are sure nothing has been sent.
 | ||||
| 					updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) | ||||
| 				} | ||||
| 				put() | ||||
| 			} | ||||
| 			if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { | ||||
| 				continue | ||||
| 			} | ||||
| 			return toRPCErr(err) | ||||
| 		} | ||||
| 		err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, stream, t, args, topts) | ||||
| 		if err != nil { | ||||
| 			if put != nil { | ||||
| 				updateRPCInfoInContext(ctx, rpcInfo{ | ||||
| 					bytesSent:     stream.BytesSent(), | ||||
| 					bytesReceived: stream.BytesReceived(), | ||||
| 				}) | ||||
| 				put() | ||||
| 				put = nil | ||||
| 			} | ||||
| 			// Retry a non-failfast RPC when
 | ||||
| 			// i) there is a connection error; or
 | ||||
| 			// ii) the server started to drain before this RPC was initiated.
 | ||||
| 			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { | ||||
| 				if c.failFast { | ||||
| 					return toRPCErr(err) | ||||
| 				} | ||||
| 			if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { | ||||
| 				continue | ||||
| 			} | ||||
| 			return toRPCErr(err) | ||||
|  | @ -266,13 +279,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli | |||
| 		err = recvResponse(ctx, cc.dopts, t, &c, stream, reply) | ||||
| 		if err != nil { | ||||
| 			if put != nil { | ||||
| 				updateRPCInfoInContext(ctx, rpcInfo{ | ||||
| 					bytesSent:     stream.BytesSent(), | ||||
| 					bytesReceived: stream.BytesReceived(), | ||||
| 				}) | ||||
| 				put() | ||||
| 				put = nil | ||||
| 			} | ||||
| 			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { | ||||
| 				if c.failFast { | ||||
| 					return toRPCErr(err) | ||||
| 				} | ||||
| 			if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { | ||||
| 				continue | ||||
| 			} | ||||
| 			return toRPCErr(err) | ||||
|  | @ -282,8 +295,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli | |||
| 		} | ||||
| 		t.CloseStream(stream, nil) | ||||
| 		if put != nil { | ||||
| 			updateRPCInfoInContext(ctx, rpcInfo{ | ||||
| 				bytesSent:     stream.BytesSent(), | ||||
| 				bytesReceived: stream.BytesReceived(), | ||||
| 			}) | ||||
| 			put() | ||||
| 			put = nil | ||||
| 		} | ||||
| 		return stream.Status().Err() | ||||
| 	} | ||||
|  |  | |||
|  | @ -669,6 +669,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) | |||
| 	} | ||||
| 	if !ok { | ||||
| 		if put != nil { | ||||
| 			updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false}) | ||||
| 			put() | ||||
| 		} | ||||
| 		return nil, nil, errConnClosing | ||||
|  | @ -676,6 +677,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) | |||
| 	t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait) | ||||
| 	if err != nil { | ||||
| 		if put != nil { | ||||
| 			updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false}) | ||||
| 			put() | ||||
| 		} | ||||
| 		return nil, nil, err | ||||
|  |  | |||
							
								
								
									
										85
									
								
								grpclb.go
								
								
								
								
							
							
						
						
									
										85
									
								
								grpclb.go
								
								
								
								
							|  | @ -145,6 +145,8 @@ type balancer struct { | |||
| 	done     bool | ||||
| 	expTimer *time.Timer | ||||
| 	rand     *rand.Rand | ||||
| 
 | ||||
| 	clientStats lbpb.ClientStats | ||||
| } | ||||
| 
 | ||||
| func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { | ||||
|  | @ -281,6 +283,34 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) { | |||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (b *balancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) { | ||||
| 	ticker := time.NewTicker(interval) | ||||
| 	defer ticker.Stop() | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-ticker.C: | ||||
| 		case <-done: | ||||
| 			return | ||||
| 		} | ||||
| 		b.mu.Lock() | ||||
| 		stats := b.clientStats | ||||
| 		b.clientStats = lbpb.ClientStats{} // Clear the stats.
 | ||||
| 		b.mu.Unlock() | ||||
| 		t := time.Now() | ||||
| 		stats.Timestamp = &lbpb.Timestamp{ | ||||
| 			Seconds: t.Unix(), | ||||
| 			Nanos:   int32(t.Nanosecond()), | ||||
| 		} | ||||
| 		if err := s.Send(&lbpb.LoadBalanceRequest{ | ||||
| 			LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{ | ||||
| 				ClientStats: &stats, | ||||
| 			}, | ||||
| 		}); err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
|  | @ -322,6 +352,14 @@ func (b *balancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry b | |||
| 		grpclog.Println("TODO: Delegation is not supported yet.") | ||||
| 		return | ||||
| 	} | ||||
| 	streamDone := make(chan struct{}) | ||||
| 	defer close(streamDone) | ||||
| 	b.mu.Lock() | ||||
| 	b.clientStats = lbpb.ClientStats{} // Clear client stats.
 | ||||
| 	b.mu.Unlock() | ||||
| 	if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 { | ||||
| 		go b.sendLoadReport(stream, d, streamDone) | ||||
| 	} | ||||
| 	// Retrieve the server list.
 | ||||
| 	for { | ||||
| 		reply, err := stream.Recv() | ||||
|  | @ -538,7 +576,32 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre | |||
| 		err = ErrClientConnClosing | ||||
| 		return | ||||
| 	} | ||||
| 	seq := b.seq | ||||
| 
 | ||||
| 	defer func() { | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		put = func() { | ||||
| 			s, ok := rpcInfoFromContext(ctx) | ||||
| 			if !ok { | ||||
| 				return | ||||
| 			} | ||||
| 			b.mu.Lock() | ||||
| 			defer b.mu.Unlock() | ||||
| 			if b.done || seq < b.seq { | ||||
| 				return | ||||
| 			} | ||||
| 			b.clientStats.NumCallsFinished++ | ||||
| 			if !s.bytesSent { | ||||
| 				b.clientStats.NumCallsFinishedWithClientFailedToSend++ | ||||
| 			} else if s.bytesReceived { | ||||
| 				b.clientStats.NumCallsFinishedKnownReceived++ | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	b.clientStats.NumCallsStarted++ | ||||
| 	if len(b.addrs) > 0 { | ||||
| 		if b.next >= len(b.addrs) { | ||||
| 			b.next = 0 | ||||
|  | @ -556,6 +619,13 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre | |||
| 				} | ||||
| 				if !opts.BlockingWait { | ||||
| 					b.next = next | ||||
| 					if a.dropForLoadBalancing { | ||||
| 						b.clientStats.NumCallsFinished++ | ||||
| 						b.clientStats.NumCallsFinishedWithDropForLoadBalancing++ | ||||
| 					} else if a.dropForRateLimiting { | ||||
| 						b.clientStats.NumCallsFinished++ | ||||
| 						b.clientStats.NumCallsFinishedWithDropForRateLimiting++ | ||||
| 					} | ||||
| 					b.mu.Unlock() | ||||
| 					err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr) | ||||
| 					return | ||||
|  | @ -569,6 +639,8 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre | |||
| 	} | ||||
| 	if !opts.BlockingWait { | ||||
| 		if len(b.addrs) == 0 { | ||||
| 			b.clientStats.NumCallsFinished++ | ||||
| 			b.clientStats.NumCallsFinishedWithClientFailedToSend++ | ||||
| 			b.mu.Unlock() | ||||
| 			err = Errorf(codes.Unavailable, "there is no address available") | ||||
| 			return | ||||
|  | @ -590,11 +662,17 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre | |||
| 	for { | ||||
| 		select { | ||||
| 		case <-ctx.Done(): | ||||
| 			b.mu.Lock() | ||||
| 			b.clientStats.NumCallsFinished++ | ||||
| 			b.clientStats.NumCallsFinishedWithClientFailedToSend++ | ||||
| 			b.mu.Unlock() | ||||
| 			err = ctx.Err() | ||||
| 			return | ||||
| 		case <-ch: | ||||
| 			b.mu.Lock() | ||||
| 			if b.done { | ||||
| 				b.clientStats.NumCallsFinished++ | ||||
| 				b.clientStats.NumCallsFinishedWithClientFailedToSend++ | ||||
| 				b.mu.Unlock() | ||||
| 				err = ErrClientConnClosing | ||||
| 				return | ||||
|  | @ -617,6 +695,13 @@ func (b *balancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Addre | |||
| 						} | ||||
| 						if !opts.BlockingWait { | ||||
| 							b.next = next | ||||
| 							if a.dropForLoadBalancing { | ||||
| 								b.clientStats.NumCallsFinished++ | ||||
| 								b.clientStats.NumCallsFinishedWithDropForLoadBalancing++ | ||||
| 							} else if a.dropForRateLimiting { | ||||
| 								b.clientStats.NumCallsFinished++ | ||||
| 								b.clientStats.NumCallsFinishedWithDropForRateLimiting++ | ||||
| 							} | ||||
| 							b.mu.Unlock() | ||||
| 							err = Errorf(codes.Unavailable, "drop requests for the addreess %s", a.addr.Addr) | ||||
| 							return | ||||
|  |  | |||
|  | @ -40,6 +40,7 @@ import ( | |||
| 	"net" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
|  | @ -47,10 +48,10 @@ import ( | |||
| 	"google.golang.org/grpc" | ||||
| 	"google.golang.org/grpc/codes" | ||||
| 	"google.golang.org/grpc/credentials" | ||||
| 	hwpb "google.golang.org/grpc/examples/helloworld/helloworld" | ||||
| 	lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" | ||||
| 	"google.golang.org/grpc/metadata" | ||||
| 	"google.golang.org/grpc/naming" | ||||
| 	testpb "google.golang.org/grpc/test/grpc_testing" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
|  | @ -172,7 +173,10 @@ func (c *serverNameCheckCreds) OverrideServerName(s string) error { | |||
| type remoteBalancer struct { | ||||
| 	sls       []*lbpb.ServerList | ||||
| 	intervals []time.Duration | ||||
| 	statsDura time.Duration | ||||
| 	done      chan struct{} | ||||
| 	mu        sync.Mutex | ||||
| 	stats     lbpb.ClientStats | ||||
| } | ||||
| 
 | ||||
| func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer { | ||||
|  | @ -198,12 +202,36 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro | |||
| 	} | ||||
| 	resp := &lbpb.LoadBalanceResponse{ | ||||
| 		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{ | ||||
| 			InitialResponse: new(lbpb.InitialLoadBalanceResponse), | ||||
| 			InitialResponse: &lbpb.InitialLoadBalanceResponse{ | ||||
| 				ClientStatsReportInterval: &lbpb.Duration{ | ||||
| 					Seconds: int64(b.statsDura.Seconds()), | ||||
| 					Nanos:   int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9), | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	if err := stream.Send(resp); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			var ( | ||||
| 				req *lbpb.LoadBalanceRequest | ||||
| 				err error | ||||
| 			) | ||||
| 			if req, err = stream.Recv(); err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			b.mu.Lock() | ||||
| 			b.stats.NumCallsStarted += req.GetClientStats().NumCallsStarted | ||||
| 			b.stats.NumCallsFinished += req.GetClientStats().NumCallsFinished | ||||
| 			b.stats.NumCallsFinishedWithDropForRateLimiting += req.GetClientStats().NumCallsFinishedWithDropForRateLimiting | ||||
| 			b.stats.NumCallsFinishedWithDropForLoadBalancing += req.GetClientStats().NumCallsFinishedWithDropForLoadBalancing | ||||
| 			b.stats.NumCallsFinishedWithClientFailedToSend += req.GetClientStats().NumCallsFinishedWithClientFailedToSend | ||||
| 			b.stats.NumCallsFinishedKnownReceived += req.GetClientStats().NumCallsFinishedKnownReceived | ||||
| 			b.mu.Unlock() | ||||
| 		} | ||||
| 	}() | ||||
| 	for k, v := range b.sls { | ||||
| 		time.Sleep(b.intervals[k]) | ||||
| 		resp = &lbpb.LoadBalanceResponse{ | ||||
|  | @ -219,11 +247,15 @@ func (b *remoteBalancer) BalanceLoad(stream *loadBalancerBalanceLoadServer) erro | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type helloServer struct { | ||||
| type testServer struct { | ||||
| 	testpb.TestServiceServer | ||||
| 
 | ||||
| 	addr string | ||||
| } | ||||
| 
 | ||||
| func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) { | ||||
| const testmdkey = "testmd" | ||||
| 
 | ||||
| func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { | ||||
| 	md, ok := metadata.FromIncomingContext(ctx) | ||||
| 	if !ok { | ||||
| 		return nil, grpc.Errorf(codes.Internal, "failed to receive metadata") | ||||
|  | @ -231,9 +263,12 @@ func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwp | |||
| 	if md == nil || md["lb-token"][0] != lbToken { | ||||
| 		return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md) | ||||
| 	} | ||||
| 	return &hwpb.HelloReply{ | ||||
| 		Message: "Hello " + in.Name + " for " + s.addr, | ||||
| 	}, nil | ||||
| 	grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr)) | ||||
| 	return &testpb.Empty{}, nil | ||||
| } | ||||
| 
 | ||||
| func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) { | ||||
|  | @ -242,7 +277,7 @@ func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) { | |||
| 			sn: sn, | ||||
| 		} | ||||
| 		s := grpc.NewServer(grpc.Creds(creds)) | ||||
| 		hwpb.RegisterGreeterServer(s, &helloServer{addr: l.Addr().String()}) | ||||
| 		testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()}) | ||||
| 		servers = append(servers, s) | ||||
| 		go func(s *grpc.Server, l net.Listener) { | ||||
| 			s.Serve(l) | ||||
|  | @ -356,9 +391,9 @@ func TestGRPCLB(t *testing.T) { | |||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to dial to the backend %v", err) | ||||
| 	} | ||||
| 	helloC := hwpb.NewGreeterClient(cc) | ||||
| 	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { | ||||
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err) | ||||
| 	testC := testpb.NewTestServiceClient(cc) | ||||
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { | ||||
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 	} | ||||
| 	cc.Close() | ||||
| } | ||||
|  | @ -393,22 +428,22 @@ func TestDropRequest(t *testing.T) { | |||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to dial to the backend %v", err) | ||||
| 	} | ||||
| 	helloC := hwpb.NewGreeterClient(cc) | ||||
| 	testC := testpb.NewTestServiceClient(cc) | ||||
| 	// The 1st, non-fail-fast RPC should succeed.  This ensures both server
 | ||||
| 	// connections are made, because the first one has DropForLoadBalancing set to true.
 | ||||
| 	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil { | ||||
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err) | ||||
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { | ||||
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 	} | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		// Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing
 | ||||
| 		// set to true.
 | ||||
| 		if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable { | ||||
| 			t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable) | ||||
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { | ||||
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) | ||||
| 		} | ||||
| 		// Even fail-fast RPCs should succeed since they choose the
 | ||||
| 		// non-drop-request backend according to the round robin policy.
 | ||||
| 		if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { | ||||
| 			t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err) | ||||
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { | ||||
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 		} | ||||
| 	} | ||||
| 	cc.Close() | ||||
|  | @ -443,10 +478,10 @@ func TestDropRequestFailedNonFailFast(t *testing.T) { | |||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to dial to the backend %v", err) | ||||
| 	} | ||||
| 	helloC := hwpb.NewGreeterClient(cc) | ||||
| 	testC := testpb.NewTestServiceClient(cc) | ||||
| 	ctx, _ = context.WithTimeout(context.Background(), 10*time.Millisecond) | ||||
| 	if _, err := helloC.SayHello(ctx, &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { | ||||
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.DeadlineExceeded) | ||||
| 	if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { | ||||
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded) | ||||
| 	} | ||||
| 	cc.Close() | ||||
| } | ||||
|  | @ -493,19 +528,19 @@ func TestServerExpiration(t *testing.T) { | |||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to dial to the backend %v", err) | ||||
| 	} | ||||
| 	helloC := hwpb.NewGreeterClient(cc) | ||||
| 	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { | ||||
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err) | ||||
| 	testC := testpb.NewTestServiceClient(cc) | ||||
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { | ||||
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 	} | ||||
| 	// Sleep and wake up when the first server list gets expired.
 | ||||
| 	time.Sleep(150 * time.Millisecond) | ||||
| 	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable { | ||||
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable) | ||||
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { | ||||
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) | ||||
| 	} | ||||
| 	// A non-failfast rpc should be succeeded after the second server list is received from
 | ||||
| 	// the remote load balancer.
 | ||||
| 	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil { | ||||
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err) | ||||
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { | ||||
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 	} | ||||
| 	cc.Close() | ||||
| } | ||||
|  | @ -551,23 +586,24 @@ func TestBalancerDisconnects(t *testing.T) { | |||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to dial to the backend %v", err) | ||||
| 	} | ||||
| 	helloC := hwpb.NewGreeterClient(cc) | ||||
| 	var message string | ||||
| 	if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { | ||||
| 		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err) | ||||
| 	testC := testpb.NewTestServiceClient(cc) | ||||
| 	var previousTrailer string | ||||
| 	trailer := metadata.MD{} | ||||
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil { | ||||
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 	} else { | ||||
| 		message = resp.Message | ||||
| 		previousTrailer = trailer[testmdkey][0] | ||||
| 	} | ||||
| 	// The initial resolver update contains lbs[0] and lbs[1].
 | ||||
| 	// When lbs[0] is stopped, lbs[1] should be used.
 | ||||
| 	lbs[0].Stop() | ||||
| 	for { | ||||
| 		if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { | ||||
| 			t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err) | ||||
| 		} else if resp.Message != message { | ||||
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil { | ||||
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 		} else if trailer[testmdkey][0] != previousTrailer { | ||||
| 			// A new backend server should receive the request.
 | ||||
| 			// The response contains the backend address, so the message should be different from the previous one.
 | ||||
| 			message = resp.Message | ||||
| 			// The trailer contains the backend address, so the trailer should be different from the previous one.
 | ||||
| 			previousTrailer = trailer[testmdkey][0] | ||||
| 			break | ||||
| 		} | ||||
| 		time.Sleep(100 * time.Millisecond) | ||||
|  | @ -585,14 +621,194 @@ func TestBalancerDisconnects(t *testing.T) { | |||
| 	// Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used.
 | ||||
| 	lbs[1].Stop() | ||||
| 	for { | ||||
| 		if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { | ||||
| 			t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err) | ||||
| 		} else if resp.Message != message { | ||||
| 		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer)); err != nil { | ||||
| 			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 		} else if trailer[testmdkey][0] != previousTrailer { | ||||
| 			// A new backend server should receive the request.
 | ||||
| 			// The response contains the backend address, so the message should be different from the previous one.
 | ||||
| 			// The trailer contains the backend address, so the trailer should be different from the previous one.
 | ||||
| 			break | ||||
| 		} | ||||
| 		time.Sleep(100 * time.Millisecond) | ||||
| 	} | ||||
| 	cc.Close() | ||||
| } | ||||
| 
 | ||||
| type failPreRPCCred struct{} | ||||
| 
 | ||||
| func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { | ||||
| 	if strings.Contains(uri[0], "failtosend") { | ||||
| 		return nil, fmt.Errorf("rpc should fail to send") | ||||
| 	} | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| func (failPreRPCCred) RequireTransportSecurity() bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func TestGRPCLBStatsUnary(t *testing.T) { | ||||
| 	var ( | ||||
| 		countNormalRPC    = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting.
 | ||||
| 		countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting.
 | ||||
| 	) | ||||
| 	tss, cleanup, err := newLoadBalancer(3) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to create new load balancer: %v", err) | ||||
| 	} | ||||
| 	defer cleanup() | ||||
| 	tss.ls.sls = []*lbpb.ServerList{{ | ||||
| 		Servers: []*lbpb.Server{{ | ||||
| 			IpAddress:            tss.beIPs[0], | ||||
| 			Port:                 int32(tss.bePorts[0]), | ||||
| 			LoadBalanceToken:     lbToken, | ||||
| 			DropForLoadBalancing: true, | ||||
| 		}, { | ||||
| 			IpAddress:           tss.beIPs[1], | ||||
| 			Port:                int32(tss.bePorts[1]), | ||||
| 			LoadBalanceToken:    lbToken, | ||||
| 			DropForRateLimiting: true, | ||||
| 		}, { | ||||
| 			IpAddress:            tss.beIPs[2], | ||||
| 			Port:                 int32(tss.bePorts[2]), | ||||
| 			LoadBalanceToken:     lbToken, | ||||
| 			DropForLoadBalancing: false, | ||||
| 		}}, | ||||
| 	}} | ||||
| 	tss.ls.intervals = []time.Duration{0} | ||||
| 	tss.ls.statsDura = 100 * time.Millisecond | ||||
| 	creds := serverNameCheckCreds{ | ||||
| 		expected: besn, | ||||
| 	} | ||||
| 	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) | ||||
| 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{ | ||||
| 		addrs: []string{tss.lbAddr}, | ||||
| 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{})) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to dial to the backend %v", err) | ||||
| 	} | ||||
| 	testC := testpb.NewTestServiceClient(cc) | ||||
| 	// The first non-failfast RPC succeeds, all connections are up.
 | ||||
| 	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { | ||||
| 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 	} | ||||
| 	for i := 0; i < countNormalRPC-1; i++ { | ||||
| 		testC.EmptyCall(context.Background(), &testpb.Empty{}) | ||||
| 	} | ||||
| 	for i := 0; i < countFailedToSend; i++ { | ||||
| 		grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc) | ||||
| 	} | ||||
| 	cc.Close() | ||||
| 
 | ||||
| 	time.Sleep(1 * time.Second) | ||||
| 	tss.ls.mu.Lock() | ||||
| 	if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) { | ||||
| 		t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) { | ||||
| 		t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 { | ||||
| 		t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 { | ||||
| 		t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 { | ||||
| 		t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 { | ||||
| 		t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC) | ||||
| 	} | ||||
| 	tss.ls.mu.Unlock() | ||||
| } | ||||
| 
 | ||||
| func TestGRPCLBStatsStreaming(t *testing.T) { | ||||
| 	var ( | ||||
| 		countNormalRPC    = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting.
 | ||||
| 		countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting.
 | ||||
| 	) | ||||
| 	tss, cleanup, err := newLoadBalancer(3) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to create new load balancer: %v", err) | ||||
| 	} | ||||
| 	defer cleanup() | ||||
| 	tss.ls.sls = []*lbpb.ServerList{{ | ||||
| 		Servers: []*lbpb.Server{{ | ||||
| 			IpAddress:            tss.beIPs[0], | ||||
| 			Port:                 int32(tss.bePorts[0]), | ||||
| 			LoadBalanceToken:     lbToken, | ||||
| 			DropForLoadBalancing: true, | ||||
| 		}, { | ||||
| 			IpAddress:           tss.beIPs[1], | ||||
| 			Port:                int32(tss.bePorts[1]), | ||||
| 			LoadBalanceToken:    lbToken, | ||||
| 			DropForRateLimiting: true, | ||||
| 		}, { | ||||
| 			IpAddress:            tss.beIPs[2], | ||||
| 			Port:                 int32(tss.bePorts[2]), | ||||
| 			LoadBalanceToken:     lbToken, | ||||
| 			DropForLoadBalancing: false, | ||||
| 		}}, | ||||
| 	}} | ||||
| 	tss.ls.intervals = []time.Duration{0} | ||||
| 	tss.ls.statsDura = 100 * time.Millisecond | ||||
| 	creds := serverNameCheckCreds{ | ||||
| 		expected: besn, | ||||
| 	} | ||||
| 	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) | ||||
| 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{ | ||||
| 		addrs: []string{tss.lbAddr}, | ||||
| 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{})) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Failed to dial to the backend %v", err) | ||||
| 	} | ||||
| 	testC := testpb.NewTestServiceClient(cc) | ||||
| 	// The first non-failfast RPC succeeds, all connections are up.
 | ||||
| 	var stream testpb.TestService_FullDuplexCallClient | ||||
| 	stream, err = testC.FullDuplexCall(context.Background(), grpc.FailFast(false)) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err) | ||||
| 	} | ||||
| 	for { | ||||
| 		if _, err = stream.Recv(); err == io.EOF { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	for i := 0; i < countNormalRPC-1; i++ { | ||||
| 		stream, err = testC.FullDuplexCall(context.Background()) | ||||
| 		if err == nil { | ||||
| 			// Wait for stream to end if err is nil.
 | ||||
| 			for { | ||||
| 				if _, err = stream.Recv(); err == io.EOF { | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	for i := 0; i < countFailedToSend; i++ { | ||||
| 		grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend") | ||||
| 	} | ||||
| 	cc.Close() | ||||
| 
 | ||||
| 	time.Sleep(1 * time.Second) | ||||
| 	tss.ls.mu.Lock() | ||||
| 	if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) { | ||||
| 		t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) { | ||||
| 		t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 { | ||||
| 		t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 { | ||||
| 		t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 { | ||||
| 		t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend) | ||||
| 	} | ||||
| 	if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 { | ||||
| 		t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC) | ||||
| 	} | ||||
| 	tss.ls.mu.Unlock() | ||||
| } | ||||
|  |  | |||
							
								
								
									
										23
									
								
								rpc_util.go
								
								
								
								
							
							
						
						
									
										23
									
								
								rpc_util.go
								
								
								
								
							|  | @ -345,6 +345,29 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type rpcInfo struct { | ||||
| 	bytesSent     bool | ||||
| 	bytesReceived bool | ||||
| } | ||||
| 
 | ||||
| type rpcInfoContextKey struct{} | ||||
| 
 | ||||
| func newContextWithRPCInfo(ctx context.Context) context.Context { | ||||
| 	return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{}) | ||||
| } | ||||
| 
 | ||||
| func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) { | ||||
| 	s, ok = ctx.Value(rpcInfoContextKey{}).(*rpcInfo) | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func updateRPCInfoInContext(ctx context.Context, s rpcInfo) { | ||||
| 	if ss, ok := rpcInfoFromContext(ctx); ok { | ||||
| 		*ss = s | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // Code returns the error code for err if it was produced by the rpc system.
 | ||||
| // Otherwise, it returns codes.Unknown.
 | ||||
| //
 | ||||
|  |  | |||
							
								
								
									
										16
									
								
								stream.go
								
								
								
								
							
							
						
						
									
										16
									
								
								stream.go
								
								
								
								
							|  | @ -151,6 +151,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth | |||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| 	ctx = newContextWithRPCInfo(ctx) | ||||
| 	sh := cc.dopts.copts.StatsHandler | ||||
| 	if sh != nil { | ||||
| 		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) | ||||
|  | @ -193,14 +194,17 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth | |||
| 
 | ||||
| 		s, err = t.NewStream(ctx, callHdr) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(transport.ConnectionError); ok && put != nil { | ||||
| 				// If error is connection error, transport was sending data on wire,
 | ||||
| 				// and we are not sure if anything has been sent on wire.
 | ||||
| 				// If error is not connection error, we are sure nothing has been sent.
 | ||||
| 				updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false}) | ||||
| 			} | ||||
| 			if put != nil { | ||||
| 				put() | ||||
| 				put = nil | ||||
| 			} | ||||
| 			if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { | ||||
| 				if c.failFast { | ||||
| 					return nil, toRPCErr(err) | ||||
| 				} | ||||
| 			if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast { | ||||
| 				continue | ||||
| 			} | ||||
| 			return nil, toRPCErr(err) | ||||
|  | @ -463,6 +467,10 @@ func (cs *clientStream) finish(err error) { | |||
| 		o.after(&cs.c) | ||||
| 	} | ||||
| 	if cs.put != nil { | ||||
| 		updateRPCInfoInContext(cs.s.Context(), rpcInfo{ | ||||
| 			bytesSent:     cs.s.BytesSent(), | ||||
| 			bytesReceived: cs.s.BytesReceived(), | ||||
| 		}) | ||||
| 		cs.put() | ||||
| 		cs.put = nil | ||||
| 	} | ||||
|  |  | |||
|  | @ -493,6 +493,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea | |||
| 			return nil, connectionErrorf(true, err, "transport: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 	s.bytesSent = true | ||||
| 
 | ||||
| 	if t.statsHandler != nil { | ||||
| 		outHeader := &stats.OutHeader{ | ||||
| 			Client:      true, | ||||
|  | @ -958,6 +960,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { | |||
| 	if !ok { | ||||
| 		return | ||||
| 	} | ||||
| 	s.bytesReceived = true | ||||
| 	var state decodeState | ||||
| 	for _, hf := range frame.Fields { | ||||
| 		if err := state.processHeaderField(hf); err != nil { | ||||
|  |  | |||
|  | @ -220,6 +220,10 @@ type Stream struct { | |||
| 	rstStream bool | ||||
| 	// rstError is the error that needs to be sent along with the RST_STREAM frame.
 | ||||
| 	rstError http2.ErrCode | ||||
| 	// bytesSent and bytesReceived indicates whether any bytes have been sent or
 | ||||
| 	// received on this stream.
 | ||||
| 	bytesSent     bool | ||||
| 	bytesReceived bool | ||||
| } | ||||
| 
 | ||||
| // RecvCompress returns the compression algorithm applied to the inbound
 | ||||
|  | @ -341,6 +345,20 @@ func (s *Stream) finish(st *status.Status) { | |||
| 	close(s.done) | ||||
| } | ||||
| 
 | ||||
| // BytesSent indicates whether any bytes have been sent on this stream.
 | ||||
| func (s *Stream) BytesSent() bool { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	return s.bytesSent | ||||
| } | ||||
| 
 | ||||
| // BytesReceived indicates whether any bytes have been received on this stream.
 | ||||
| func (s *Stream) BytesReceived() bool { | ||||
| 	s.mu.Lock() | ||||
| 	defer s.mu.Unlock() | ||||
| 	return s.bytesReceived | ||||
| } | ||||
| 
 | ||||
| // GoString is implemented by Stream so context.String() won't
 | ||||
| // race when printing %#v.
 | ||||
| func (s *Stream) GoString() string { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue