diff --git a/interop/xds/client/client.go b/interop/xds/client/client.go index b119bcd2b..7afdb20e8 100644 --- a/interop/xds/client/client.go +++ b/interop/xds/client/client.go @@ -38,6 +38,10 @@ import ( _ "google.golang.org/grpc/xds" ) +func init() { + rpcCfgs.Store([]*rpcConfig{{typ: unaryCall}}) +} + type statsWatcherKey struct { startID int32 endID int32 @@ -73,21 +77,84 @@ func (watcher *statsWatcher) buildResp() *testpb.LoadBalancerStatsResponse { } } +type accumulatedStats struct { + mu sync.Mutex + numRpcsStartedByMethod map[string]int32 + numRpcsSucceededByMethod map[string]int32 + numRpcsFailedByMethod map[string]int32 +} + +// copyStatsMap makes a copy of the map, and also replaces the RPC type string +// to the proto string. E.g. "UnaryCall" -> "UNARY_CALL". +func copyStatsMap(originalMap map[string]int32) (newMap map[string]int32) { + newMap = make(map[string]int32) + for k, v := range originalMap { + var kk string + switch k { + case unaryCall: + kk = testpb.ClientConfigureRequest_UNARY_CALL.String() + case emptyCall: + kk = testpb.ClientConfigureRequest_EMPTY_CALL.String() + default: + logger.Warningf("unrecognized rpc type: %s", k) + } + if kk == "" { + continue + } + newMap[kk] = v + } + return newMap +} + +func (as *accumulatedStats) buildResp() *testpb.LoadBalancerAccumulatedStatsResponse { + as.mu.Lock() + defer as.mu.Unlock() + return &testpb.LoadBalancerAccumulatedStatsResponse{ + NumRpcsStartedByMethod: copyStatsMap(as.numRpcsStartedByMethod), + NumRpcsSucceededByMethod: copyStatsMap(as.numRpcsSucceededByMethod), + NumRpcsFailedByMethod: copyStatsMap(as.numRpcsFailedByMethod), + } +} + +func (as *accumulatedStats) startRPC(rpcType string) { + as.mu.Lock() + defer as.mu.Unlock() + as.numRpcsStartedByMethod[rpcType]++ +} + +func (as *accumulatedStats) finishRPC(rpcType string, failed bool) { + as.mu.Lock() + defer as.mu.Unlock() + if failed { + as.numRpcsFailedByMethod[rpcType]++ + return + } + as.numRpcsSucceededByMethod[rpcType]++ +} + var ( failOnFailedRPC = flag.Bool("fail_on_failed_rpc", false, "Fail client if any RPCs fail after first success") numChannels = flag.Int("num_channels", 1, "Num of channels") printResponse = flag.Bool("print_response", false, "Write RPC response to stdout") qps = flag.Int("qps", 1, "QPS per channel, for each type of RPC") - rpc = flag.String("rpc", "UnaryCall", "Types of RPCs to make, ',' separated string. RPCs can be EmptyCall or UnaryCall") - rpcMetadata = flag.String("metadata", "", "The metadata to send with RPC, in format EmptyCall:key1:value1,UnaryCall:key2:value2") + rpc = flag.String("rpc", "UnaryCall", "Types of RPCs to make, ',' separated string. RPCs can be EmptyCall or UnaryCall. Deprecated: Use Configure RPC to XdsUpdateClientConfigureServiceServer instead.") + rpcMetadata = flag.String("metadata", "", "The metadata to send with RPC, in format EmptyCall:key1:value1,UnaryCall:key2:value2. Deprecated: Use Configure RPC to XdsUpdateClientConfigureServiceServer instead.") rpcTimeout = flag.Duration("rpc_timeout", 20*time.Second, "Per RPC timeout") server = flag.String("server", "localhost:8080", "Address of server to connect to") statsPort = flag.Int("stats_port", 8081, "Port to expose peer distribution stats service") + rpcCfgs atomic.Value + mu sync.Mutex currentRequestID int32 watchers = make(map[statsWatcherKey]*statsWatcher) + accStats = accumulatedStats{ + numRpcsStartedByMethod: make(map[string]int32), + numRpcsSucceededByMethod: make(map[string]int32), + numRpcsFailedByMethod: make(map[string]int32), + } + // 0 or 1 representing an RPC has succeeded. Use hasRPCSucceeded and // setRPCSucceeded to access in a safe manner. rpcSucceeded uint32 @@ -163,6 +230,47 @@ func (s *statsService) GetClientStats(ctx context.Context, in *testpb.LoadBalanc } } +func (s *statsService) GetClientAccumulatedStats(ctx context.Context, in *testpb.LoadBalancerAccumulatedStatsRequest) (*testpb.LoadBalancerAccumulatedStatsResponse, error) { + return accStats.buildResp(), nil +} + +type configureService struct { + testpb.UnimplementedXdsUpdateClientConfigureServiceServer +} + +func (s *configureService) Configure(ctx context.Context, in *testpb.ClientConfigureRequest) (*testpb.ClientConfigureResponse, error) { + rpcsToMD := make(map[testpb.ClientConfigureRequest_RpcType][]string) + for _, typ := range in.GetTypes() { + rpcsToMD[typ] = nil + } + for _, md := range in.GetMetadata() { + typ := md.GetType() + strs, ok := rpcsToMD[typ] + if !ok { + continue + } + rpcsToMD[typ] = append(strs, md.GetKey(), md.GetValue()) + } + cfgs := make([]*rpcConfig, 0, len(rpcsToMD)) + for typ, md := range rpcsToMD { + var rpcType string + switch typ { + case testpb.ClientConfigureRequest_UNARY_CALL: + rpcType = unaryCall + case testpb.ClientConfigureRequest_EMPTY_CALL: + rpcType = emptyCall + default: + return nil, fmt.Errorf("unsupported RPC type: %v", typ) + } + cfgs = append(cfgs, &rpcConfig{ + typ: rpcType, + md: metadata.Pairs(md...), + }) + } + rpcCfgs.Store(cfgs) + return &testpb.ClientConfigureResponse{}, nil +} + const ( unaryCall string = "UnaryCall" emptyCall string = "EmptyCall" @@ -218,7 +326,7 @@ func parseRPCMetadata(rpcMetadataStr string, rpcs []string) []*rpcConfig { func main() { flag.Parse() - rpcCfgs := parseRPCMetadata(*rpcMetadata, parseRPCTypes(*rpc)) + rpcCfgs.Store(parseRPCMetadata(*rpcMetadata, parseRPCTypes(*rpc))) lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *statsPort)) if err != nil { @@ -227,6 +335,7 @@ func main() { s := grpc.NewServer() defer s.Stop() testpb.RegisterLoadBalancerStatsServiceServer(s, &statsService{}) + testpb.RegisterXdsUpdateClientConfigureServiceServer(s, &configureService{}) go s.Serve(lis) clients := make([]testpb.TestServiceClient, *numChannels) @@ -240,7 +349,7 @@ func main() { } ticker := time.NewTicker(time.Second / time.Duration(*qps**numChannels)) defer ticker.Stop() - sendRPCs(clients, rpcCfgs, ticker) + sendRPCs(clients, ticker) } func makeOneRPC(c testpb.TestServiceClient, cfg *rpcConfig) (*peer.Peer, *rpcInfo, error) { @@ -257,6 +366,7 @@ func makeOneRPC(c testpb.TestServiceClient, cfg *rpcConfig) (*peer.Peer, *rpcInf header metadata.MD err error ) + accStats.startRPC(cfg.typ) switch cfg.typ { case unaryCall: var resp *testpb.SimpleResponse @@ -270,8 +380,10 @@ func makeOneRPC(c testpb.TestServiceClient, cfg *rpcConfig) (*peer.Peer, *rpcInf _, err = c.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&p), grpc.Header(&header)) } if err != nil { + accStats.finishRPC(cfg.typ, true) return nil, nil, err } + accStats.finishRPC(cfg.typ, false) hosts := header["hostname"] if len(hosts) > 0 { @@ -280,26 +392,28 @@ func makeOneRPC(c testpb.TestServiceClient, cfg *rpcConfig) (*peer.Peer, *rpcInf return &p, &info, err } -func sendRPCs(clients []testpb.TestServiceClient, cfgs []*rpcConfig, ticker *time.Ticker) { +func sendRPCs(clients []testpb.TestServiceClient, ticker *time.Ticker) { var i int for range ticker.C { - go func(i int) { - // Get and increment request ID, and save a list of watchers that - // are interested in this RPC. - mu.Lock() - savedRequestID := currentRequestID - currentRequestID++ - savedWatchers := []*statsWatcher{} - for key, value := range watchers { - if key.startID <= savedRequestID && savedRequestID < key.endID { - savedWatchers = append(savedWatchers, value) - } + // Get and increment request ID, and save a list of watchers that are + // interested in this RPC. + mu.Lock() + savedRequestID := currentRequestID + currentRequestID++ + savedWatchers := []*statsWatcher{} + for key, value := range watchers { + if key.startID <= savedRequestID && savedRequestID < key.endID { + savedWatchers = append(savedWatchers, value) } - mu.Unlock() + } + mu.Unlock() - c := clients[i] + // Get the RPC metadata configurations from the Configure RPC. + cfgs := rpcCfgs.Load().([]*rpcConfig) - for _, cfg := range cfgs { + c := clients[i] + for _, cfg := range cfgs { + go func(cfg *rpcConfig) { p, info, err := makeOneRPC(c, cfg) for _, watcher := range savedWatchers { @@ -325,8 +439,8 @@ func sendRPCs(clients []testpb.TestServiceClient, cfgs []*rpcConfig, ticker *tim fmt.Printf("RPC %q, failed with %v\n", cfg.typ, err) } } - } - }(i) + }(cfg) + } i = (i + 1) % len(clients) } }