diff --git a/balancer.go b/balancer.go
index 0fec7b6a3..ab65049dd 100644
--- a/balancer.go
+++ b/balancer.go
@@ -403,6 +403,6 @@ type pickFirst struct {
 	*roundRobin
 }
 
-func pickFirstBalancer(r naming.Resolver) Balancer {
+func pickFirstBalancerV1(r naming.Resolver) Balancer {
 	return &pickFirst{&roundRobin{r: r}}
 }
diff --git a/balancer/balancer.go b/balancer/balancer.go
index 6d83a1044..84e10b630 100644
--- a/balancer/balancer.go
+++ b/balancer/balancer.go
@@ -182,6 +182,10 @@ type Picker interface {
 // the connectivity states.
 //
 // It also generates and updates the Picker used by gRPC to pick SubConns for RPCs.
+//
+// HandleSubConnStateChange, HandleResolvedAddrs and Close are guaranteed
+// to be called synchronously from the same goroutine.
+// There is no such guarantee for Picker.Pick; it may be called at any time.
 type Balancer interface {
 	// HandleSubConnStateChange is called by gRPC when the connectivity state
 	// of sc has changed.
@@ -196,6 +200,7 @@ type Balancer interface {
 	// An empty address slice and a non-nil error will be passed if the resolver returns
 	// non-nil error to gRPC.
 	HandleResolvedAddrs([]resolver.Address, error)
-	// Close closes the balancer.
+	// Close closes the balancer. The balancer is not required to call
+	// ClientConn.RemoveSubConn for its existing SubConns.
 	Close()
 }
diff --git a/balancer/roundrobin/roundrobin.go b/balancer/roundrobin/roundrobin.go
new file mode 100644
index 000000000..453ff4ecd
--- /dev/null
+++ b/balancer/roundrobin/roundrobin.go
@@ -0,0 +1,241 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package roundrobin defines a roundrobin balancer. The roundrobin balancer is
+// installed as one of the default balancers in gRPC; users don't need to
+// explicitly install it.
+package roundrobin
+
+import (
+	"sync"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/connectivity"
+	"google.golang.org/grpc/grpclog"
+	"google.golang.org/grpc/resolver"
+)
+
+// newBuilder creates a new roundrobin balancer builder.
+func newBuilder() balancer.Builder {
+	return &rrBuilder{}
+}
+
+func init() {
+	balancer.Register(newBuilder())
+}
+
+type rrBuilder struct{}
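For orientation before Build: this is roughly how the builder gets wired into a ClientConn. The sketch is in-package (newBuilder is unexported, which is also how the tests below use it), and it relies on the test-only WithBalancerBuilder dial option and the manual resolver, both part of this change set; the target and addresses are illustrative, and the "log", "google.golang.org/grpc" and "google.golang.org/grpc/resolver/manual" imports are assumed.

// Sketch: drive the roundrobin balancer through a manual resolver.
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()

cc, err := grpc.Dial(r.Scheme()+":///test.server",
	grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
if err != nil {
	log.Fatalf("Dial: %v", err)
}
defer cc.Close()

// Publishing an address list makes HandleResolvedAddrs (below) create one
// SubConn per address; RPCs then round robin over the READY ones.
r.NewAddress([]resolver.Address{{Addr: "localhost:50051"}, {Addr: "localhost:50052"}})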
+func (*rrBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
+	return &rrBalancer{
+		cc:       cc,
+		subConns: make(map[resolver.Address]balancer.SubConn),
+		scStates: make(map[balancer.SubConn]connectivity.State),
+		csEvltr:  &connectivityStateEvaluator{},
+		// Initialize picker to a picker that always returns
+		// ErrNoSubConnAvailable, because when the state of a SubConn changes,
+		// we may call UpdateBalancerState with this picker.
+		picker: newPicker([]balancer.SubConn{}, nil),
+	}
+}
+
+func (*rrBuilder) Name() string {
+	return "roundrobin"
+}
+
+type rrBalancer struct {
+	cc balancer.ClientConn
+
+	csEvltr *connectivityStateEvaluator
+	state   connectivity.State
+
+	subConns map[resolver.Address]balancer.SubConn
+	scStates map[balancer.SubConn]connectivity.State
+	picker   *picker
+}
+
+func (b *rrBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
+	if err != nil {
+		grpclog.Infof("roundrobin.rrBalancer: HandleResolvedAddrs called with error %v", err)
+		return
+	}
+	grpclog.Infoln("roundrobin.rrBalancer: got new resolved addresses: ", addrs)
+	// addrsSet is the set converted from addrs; it's used for a quick lookup of an address.
+	addrsSet := make(map[resolver.Address]struct{})
+	for _, a := range addrs {
+		addrsSet[a] = struct{}{}
+		if _, ok := b.subConns[a]; !ok {
+			// a is a new address (not existing in b.subConns).
+			sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{})
+			if err != nil {
+				grpclog.Warningf("roundrobin.rrBalancer: failed to create new SubConn: %v", err)
+				continue
+			}
+			b.subConns[a] = sc
+			b.scStates[sc] = connectivity.Idle
+			sc.Connect()
+		}
+	}
+	for a, sc := range b.subConns {
+		// a was removed by the resolver.
+		if _, ok := addrsSet[a]; !ok {
+			b.cc.RemoveSubConn(sc)
+			delete(b.subConns, a)
+			// Keep the state of this sc in b.scStates until sc's state becomes Shutdown.
+			// The entry will be deleted in HandleSubConnStateChange.
+		}
+	}
+}
+
+// regeneratePicker takes a snapshot of the balancer, and generates a picker
+// from it. The picker
+//  - always returns ErrTransientFailure if the balancer is in TransientFailure,
+//  - or does round robin selection of all READY SubConns otherwise.
+func (b *rrBalancer) regeneratePicker() {
+	if b.state == connectivity.TransientFailure {
+		b.picker = newPicker(nil, balancer.ErrTransientFailure)
+		return
+	}
+	var readySCs []balancer.SubConn
+	for sc, st := range b.scStates {
+		if st == connectivity.Ready {
+			readySCs = append(readySCs, sc)
+		}
+	}
+	b.picker = newPicker(readySCs, nil)
+}
+
+func (b *rrBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
+	grpclog.Infof("roundrobin.rrBalancer: handle SubConn state change: %p, %v", sc, s)
+	oldS, ok := b.scStates[sc]
+	if !ok {
+		grpclog.Infof("roundrobin.rrBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
+		return
+	}
+	b.scStates[sc] = s
+	switch s {
+	case connectivity.Idle:
+		sc.Connect()
+	case connectivity.Shutdown:
+		// When an address was removed by the resolver, b called RemoveSubConn but
+		// kept the sc's state in scStates. Remove the state for this sc here.
+		delete(b.scStates, sc)
+	}
+
+	oldAggrState := b.state
+	b.state = b.csEvltr.recordTransition(oldS, s)
+
+	// Regenerate the picker when one of the following happens:
+	//  - this sc became ready from not-ready
+	//  - this sc became not-ready from ready
+	//  - the aggregated state of the balancer became TransientFailure from non-TransientFailure
+	//  - the aggregated state of the balancer became non-TransientFailure from TransientFailure
+	if (s == connectivity.Ready) != (oldS == connectivity.Ready) ||
+		(b.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) {
+		b.regeneratePicker()
+	}
+
+	b.cc.UpdateBalancerState(b.state, b.picker)
+	return
+}
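The two regeneration conditions above are both "did something flip" checks. Distilled into a standalone helper (hypothetical, for illustration only; assumes the connectivity package is imported):

// Sketch: when HandleSubConnStateChange must rebuild the picker.
func needsRegen(oldS, s, oldAggr, aggr connectivity.State) bool {
	scReadinessFlipped := (s == connectivity.Ready) != (oldS == connectivity.Ready)
	aggrFailureFlipped := (aggr == connectivity.TransientFailure) != (oldAggr == connectivity.TransientFailure)
	// Any other transition (e.g. Connecting -> Idle) leaves the READY set and
	// the error mode unchanged, so the old picker stays valid.
	return scReadinessFlipped || aggrFailureFlipped
}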
+
+// Close is a nop because the roundrobin balancer doesn't have internal state to
+// clean up, and it doesn't need to call RemoveSubConn for its SubConns.
+func (b *rrBalancer) Close() {
+}
+
+type picker struct {
+	// If err is not nil, Pick always returns this err. It's immutable after
+	// the picker is created.
+	err error
+
+	// subConns is the snapshot of the roundrobin balancer when this picker was
+	// created. The slice is immutable. Each Pick() will do a round robin
+	// selection from it and return the selected SubConn.
+	subConns []balancer.SubConn
+
+	mu   sync.Mutex
+	next int
+}
+
+func newPicker(scs []balancer.SubConn, err error) *picker {
+	grpclog.Infof("roundrobinPicker: newPicker called with scs: %v, %v", scs, err)
+	if err != nil {
+		return &picker{err: err}
+	}
+	return &picker{
+		subConns: scs,
+	}
+}
+
+func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
+	if p.err != nil {
+		return nil, nil, p.err
+	}
+	if len(p.subConns) <= 0 {
+		return nil, nil, balancer.ErrNoSubConnAvailable
+	}
+
+	p.mu.Lock()
+	sc := p.subConns[p.next]
+	p.next = (p.next + 1) % len(p.subConns)
+	p.mu.Unlock()
+	return sc, nil, nil
+}
+
+// connectivityStateEvaluator gets updated by addrConns when their
+// states transition, based on which it evaluates the state of
+// ClientConn.
+type connectivityStateEvaluator struct {
+	numReady            uint64 // Number of addrConns in ready state.
+	numConnecting       uint64 // Number of addrConns in connecting state.
+	numTransientFailure uint64 // Number of addrConns in transientFailure.
+}
+
+// recordTransition records a state change happening in every subConn, and based on
+// that it evaluates what the aggregated state should be.
+// It can only transition between Ready, Connecting and TransientFailure. The other states,
+// Idle and Shutdown, are transitioned into by ClientConn; at the beginning of the connection,
+// before any subConn is created, ClientConn is in the Idle state. In the end, when ClientConn
+// closes, it is in the Shutdown state.
+//
+// recordTransition should only be called synchronously from the same goroutine.
+func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State {
+	// Update counters.
+	for idx, state := range []connectivity.State{oldState, newState} {
+		updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new; for idx==0 this relies on unsigned wraparound to act as -1.
+		switch state {
+		case connectivity.Ready:
+			cse.numReady += updateVal
+		case connectivity.Connecting:
+			cse.numConnecting += updateVal
+		case connectivity.TransientFailure:
+			cse.numTransientFailure += updateVal
+		}
+	}
+
+	// Evaluate.
+	if cse.numReady > 0 {
+		return connectivity.Ready
+	}
+	if cse.numConnecting > 0 {
+		return connectivity.Connecting
+	}
+	return connectivity.TransientFailure
+}
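To make the aggregation rule concrete, a worked sequence for two hypothetical SubConns (in-package sketch; return values of the first three calls are ignored):

// Sketch: aggregated state after a few recordTransition calls.
cse := &connectivityStateEvaluator{}
cse.recordTransition(connectivity.Idle, connectivity.Connecting)  // 1 Connecting -> Connecting
cse.recordTransition(connectivity.Idle, connectivity.Connecting)  // 2 Connecting -> Connecting
cse.recordTransition(connectivity.Connecting, connectivity.Ready) // 1 Ready      -> Ready
s := cse.recordTransition(connectivity.Ready, connectivity.TransientFailure)
// s == connectivity.Connecting: nothing is Ready, but one SubConn is still
// Connecting, so the aggregate is not yet TransientFailure.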
diff --git a/balancer/roundrobin/roundrobin_test.go b/balancer/roundrobin/roundrobin_test.go
new file mode 100644
index 000000000..3b4e1305d
--- /dev/null
+++ b/balancer/roundrobin/roundrobin_test.go
@@ -0,0 +1,470 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package roundrobin
+
+import (
+	"fmt"
+	"net"
+	"sync"
+	"testing"
+	"time"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	_ "google.golang.org/grpc/grpclog/glogger"
+	"google.golang.org/grpc/peer"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/resolver/manual"
+	testpb "google.golang.org/grpc/test/grpc_testing"
+	"google.golang.org/grpc/test/leakcheck"
+)
+
+type testServer struct {
+	testpb.TestServiceServer
+}
+
+func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
+	return &testpb.Empty{}, nil
+}
+
+func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
+	return nil
+}
+
+type test struct {
+	servers   []*grpc.Server
+	addresses []string
+}
+
+func (t *test) cleanup() {
+	for _, s := range t.servers {
+		s.Stop()
+	}
+}
+
+func startTestServers(count int) (_ *test, err error) {
+	t := &test{}
+
+	defer func() {
+		if err != nil {
+			for _, s := range t.servers {
+				s.Stop()
+			}
+		}
+	}()
+	for i := 0; i < count; i++ {
+		lis, err := net.Listen("tcp", "localhost:0")
+		if err != nil {
+			return nil, fmt.Errorf("failed to listen: %v", err)
+		}
+
+		s := grpc.NewServer()
+		testpb.RegisterTestServiceServer(s, &testServer{})
+		t.servers = append(t.servers, s)
+		t.addresses = append(t.addresses, lis.Addr().String())
+
+		go func(s *grpc.Server, l net.Listener) {
+			s.Serve(l)
+		}(s, lis)
+	}
+
+	return t, nil
+}
+
+func TestOneBackend(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
+	test, err := startTestServers(1)
+	if err != nil {
+		t.Fatalf("failed to start servers: %v", err)
+	}
+	defer test.cleanup()
+
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	testc := testpb.NewTestServiceClient(cc)
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
+	// The second RPC should succeed.
+	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+	}
+}
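Every test in this file opens with the same probe: no addresses have been published yet, so a wait-for-ready RPC with a tiny deadline must end in DeadlineExceeded. A condensed sketch of that shared pattern (the helper itself is hypothetical, not part of the PR):

// Sketch: the "no backend yet" probe the tests repeat inline.
func expectNoBackend(t *testing.T, testc testpb.TestServiceClient) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()
	// FailFast(false) makes the RPC wait for a READY SubConn rather than
	// failing immediately, so the 1ms deadline is what ends it.
	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
	}
}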
+
+func TestBackendsRoundRobin(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
+	backendCount := 5
+	test, err := startTestServers(backendCount)
+	if err != nil {
+		t.Fatalf("failed to start servers: %v", err)
+	}
+	defer test.cleanup()
+
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	testc := testpb.NewTestServiceClient(cc)
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	var resolvedAddrs []resolver.Address
+	for i := 0; i < backendCount; i++ {
+		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
+	}
+
+	r.NewAddress(resolvedAddrs)
+	var p peer.Peer
+	// Make sure connections to all servers are up.
+	for si := 0; si < backendCount; si++ {
+		var connected bool
+		for i := 0; i < 1000; i++ {
+			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+			}
+			if p.Addr.String() == test.addresses[si] {
+				connected = true
+				break
+			}
+			time.Sleep(time.Millisecond)
+		}
+		if !connected {
+			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
+		}
+	}
+
+	for i := 0; i < 3*backendCount; i++ {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+		}
+		if p.Addr.String() != test.addresses[i%backendCount] {
+			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
+		}
+	}
+}
+
+func TestAddressesRemoved(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
+	test, err := startTestServers(1)
+	if err != nil {
+		t.Fatalf("failed to start servers: %v", err)
+	}
+	defer test.cleanup()
+
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	testc := testpb.NewTestServiceClient(cc)
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
+	// The second RPC should succeed.
+	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+	}
+
+	r.NewAddress([]resolver.Address{})
+	for i := 0; i < 1000; i++ {
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+		if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) == codes.DeadlineExceeded {
+			return
+		}
+		time.Sleep(time.Millisecond)
+	}
+	t.Fatalf("No RPC failed after removing all addresses, want RPC to fail with DeadlineExceeded")
+}
+
+func TestCloseWithPendingRPC(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
+	test, err := startTestServers(1)
+	if err != nil {
+		t.Fatalf("failed to start servers: %v", err)
+	}
+	defer test.cleanup()
+
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	testc := testpb.NewTestServiceClient(cc)
+
+	var wg sync.WaitGroup
+	for i := 0; i < 3; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			// This RPC blocks until cc is closed.
+			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+			if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) == codes.DeadlineExceeded {
+				t.Errorf("RPC failed because of deadline after cc is closed; want the client-connection-is-closing error")
+			}
+			cancel()
+		}()
+	}
+	cc.Close()
+	wg.Wait()
+}
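What lets those pending RPCs return promptly is the pickerWrapper introduced later in this diff: ClientConn.Close calls blockingpicker.close(), and every goroutine parked inside pick returns ErrClientConnClosing. The contract in miniature (a sketch using the picker_wrapper.go names defined below; "fmt", "context" and the balancer package are assumed imported):

// Sketch: closing the wrapper wakes every blocked pick.
bp := newPickerWrapper()
errCh := make(chan error, 1)
go func() {
	// Blocks: no picker has been set yet.
	_, _, err := bp.pick(context.Background(), false, balancer.PickOptions{})
	errCh <- err
}()
bp.close()           // what cc.Close() triggers
fmt.Println(<-errCh) // ErrClientConnClosing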
+
+func TestNewAddressWhileBlocking(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
+	test, err := startTestServers(1)
+	if err != nil {
+		t.Fatalf("failed to start servers: %v", err)
+	}
+	defer test.cleanup()
+
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	testc := testpb.NewTestServiceClient(cc)
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
+	// The second RPC should succeed.
+	ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+	}
+
+	r.NewAddress([]resolver.Address{})
+
+	var wg sync.WaitGroup
+	for i := 0; i < 3; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			// This RPC blocks until NewAddress is called.
+			testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false))
+		}()
+	}
+	time.Sleep(50 * time.Millisecond)
+	r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
+	wg.Wait()
+}
+
+func TestOneServerDown(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
+	backendCount := 3
+	test, err := startTestServers(backendCount)
+	if err != nil {
+		t.Fatalf("failed to start servers: %v", err)
+	}
+	defer test.cleanup()
+
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	testc := testpb.NewTestServiceClient(cc)
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	var resolvedAddrs []resolver.Address
+	for i := 0; i < backendCount; i++ {
+		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
+	}
+
+	r.NewAddress(resolvedAddrs)
+	var p peer.Peer
+	// Make sure connections to all servers are up.
+	for si := 0; si < backendCount; si++ {
+		var connected bool
+		for i := 0; i < 1000; i++ {
+			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+			}
+			if p.Addr.String() == test.addresses[si] {
+				connected = true
+				break
+			}
+			time.Sleep(time.Millisecond)
+		}
+		if !connected {
+			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
+		}
+	}
+
+	for i := 0; i < 3*backendCount; i++ {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+		}
+		if p.Addr.String() != test.addresses[i%backendCount] {
+			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
+		}
+	}
+
+	// Stop one server; RPCs should round robin among the remaining servers.
+	backendCount--
+	test.servers[backendCount].Stop()
+	// Loop until we see server[backendCount-1] twice without seeing server[backendCount].
+	var targetSeen int
+	for i := 0; i < 1000; i++ {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+		}
+		switch p.Addr.String() {
+		case test.addresses[backendCount-1]:
+			targetSeen++
+		case test.addresses[backendCount]:
+			// Reset targetSeen if the peer is server[backendCount].
+			targetSeen = 0
+		}
+		// Break to make sure the last picked address is the last remaining
+		// server (server[backendCount-1]), so the following for loop won't be flaky.
+		if targetSeen >= 2 {
+			break
+		}
+	}
+	if targetSeen != 2 {
+		t.Fatal("Failed to see server[backendCount-1] twice without seeing server[backendCount]")
+	}
+	for i := 0; i < 3*backendCount; i++ {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+		}
+		if p.Addr.String() != test.addresses[i%backendCount] {
+			t.Errorf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
+		}
+	}
+}
+
+func TestAllServersDown(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, cleanup := manual.GenerateAndRegisterManualResolver()
+	defer cleanup()
+
+	backendCount := 3
+	test, err := startTestServers(backendCount)
+	if err != nil {
+		t.Fatalf("failed to start servers: %v", err)
+	}
+	defer test.cleanup()
+
+	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerBuilder(newBuilder()))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	testc := testpb.NewTestServiceClient(cc)
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	var resolvedAddrs []resolver.Address
+	for i := 0; i < backendCount; i++ {
+		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
+	}
+
+	r.NewAddress(resolvedAddrs)
+	var p peer.Peer
+	// Make sure connections to all servers are up.
+	for si := 0; si < backendCount; si++ {
+		var connected bool
+		for i := 0; i < 1000; i++ {
+			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+			}
+			if p.Addr.String() == test.addresses[si] {
+				connected = true
+				break
+			}
+			time.Sleep(time.Millisecond)
+		}
+		if !connected {
+			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
+		}
+	}
+
+	for i := 0; i < 3*backendCount; i++ {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+		}
+		if p.Addr.String() != test.addresses[i%backendCount] {
+			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
+		}
+	}
+
+	// All servers are stopped; a failfast RPC should fail with Unavailable.
+	for i := 0; i < backendCount; i++ {
+		test.servers[i].Stop()
+	}
+	time.Sleep(100 * time.Millisecond)
+	for i := 0; i < 1000; i++ {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.Unavailable {
+			return
+		}
+		time.Sleep(time.Millisecond)
+	}
+	t.Fatalf("Failfast RPCs didn't fail with Unavailable after all servers are stopped")
+}
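The next file is the heart of this change: a wrapper that funnels every balancer callback onto one watcher goroutine, which is what lets the Balancer interface promise synchronous, same-goroutine callbacks. The pattern, reduced to its essentials (a hypothetical standalone sketch; the real types follow):

// Sketch: many producers, one worker; callbacks never run concurrently.
updates := make(chan func(), 1) // the real code uses an unbounded buffer instead
done := make(chan struct{})

go func() { // the watcher goroutine
	for {
		select {
		case f := <-updates:
			f() // balancer methods run here, one at a time
		case <-done:
			return
		}
	}
}()

// Producers (transport goroutines, the resolver) enqueue work:
updates <- func() { /* e.g. balancer.HandleSubConnStateChange(sc, state) */ }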
diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go
index 404377dbc..e4a95fd5c 100644
--- a/balancer_conn_wrappers.go
+++ b/balancer_conn_wrappers.go
@@ -27,20 +27,141 @@ import (
 	"google.golang.org/grpc/resolver"
 )
 
-// TODO(bar) move ClientConn methods to clientConn file.
+// scStateUpdate contains the subConn and the new state it changed to.
+type scStateUpdate struct {
+	sc    balancer.SubConn
+	state connectivity.State
+}
 
-func (cc *ClientConn) updatePicker(p balancer.Picker) {
-	// TODO(bar) add a goroutine and sync it.
-	// TODO(bar) implement blocking behavior and unblock the previous pick.
-	cc.pmu.Lock()
-	cc.picker = p
-	cc.pmu.Unlock()
+// scStateUpdateBuffer is an unbounded channel for scStateUpdate structs.
+// TODO make a general purpose buffer that uses interface{}.
+type scStateUpdateBuffer struct {
+	c       chan *scStateUpdate
+	mu      sync.Mutex
+	backlog []*scStateUpdate
+}
+
+func newSCStateUpdateBuffer() *scStateUpdateBuffer {
+	return &scStateUpdateBuffer{
+		c: make(chan *scStateUpdate, 1),
+	}
+}
+
+func (b *scStateUpdateBuffer) put(t *scStateUpdate) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	if len(b.backlog) == 0 {
+		select {
+		case b.c <- t:
+			return
+		default:
+		}
+	}
+	b.backlog = append(b.backlog, t)
+}
+
+func (b *scStateUpdateBuffer) load() {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	if len(b.backlog) > 0 {
+		select {
+		case b.c <- b.backlog[0]:
+			b.backlog[0] = nil
+			b.backlog = b.backlog[1:]
+		default:
+		}
+	}
+}
+
+// get returns the channel that receives scStateUpdates from the buffer.
+//
+// Upon receiving, the caller should call load to send another
+// scStateUpdate onto the channel if there is any.
+func (b *scStateUpdateBuffer) get() <-chan *scStateUpdate {
+	return b.c
+}
+
+// resolverUpdate contains the new resolved addresses or an error if there's
+// any.
+type resolverUpdate struct {
+	addrs []resolver.Address
+	err   error
 }
 
 // ccBalancerWrapper is a wrapper on top of cc for balancers.
 // It implements balancer.ClientConn interface.
 type ccBalancerWrapper struct {
-	cc *ClientConn
+	cc               *ClientConn
+	balancer         balancer.Balancer
+	stateChangeQueue *scStateUpdateBuffer
+	resolverUpdateCh chan *resolverUpdate
+	done             chan struct{}
+}
+
+func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper {
+	ccb := &ccBalancerWrapper{
+		cc:               cc,
+		stateChangeQueue: newSCStateUpdateBuffer(),
+		resolverUpdateCh: make(chan *resolverUpdate, 1),
+		done:             make(chan struct{}),
+	}
+	go ccb.watcher()
+	ccb.balancer = b.Build(ccb, bopts)
+	return ccb
+}
+
+// watcher makes sure the balancer's methods are called sequentially, so the
+// balancer can be implemented lock-free.
+func (ccb *ccBalancerWrapper) watcher() {
+	for {
+		select {
+		case t := <-ccb.stateChangeQueue.get():
+			ccb.stateChangeQueue.load()
+			ccb.balancer.HandleSubConnStateChange(t.sc, t.state)
+		case t := <-ccb.resolverUpdateCh:
+			ccb.balancer.HandleResolvedAddrs(t.addrs, t.err)
+		case <-ccb.done:
+		}
+
+		select {
+		case <-ccb.done:
+			ccb.balancer.Close()
+			return
+		default:
+		}
+	}
+}
+
+func (ccb *ccBalancerWrapper) close() {
+	close(ccb.done)
+}
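put, load and get implement an unbounded FIFO: put never blocks (overflow is parked in backlog), and the consumer re-arms the one-slot channel by calling load after every receive. A minimal sketch of that consumption contract, assuming the types above and an "fmt" import:

// Sketch: drain an scStateUpdateBuffer without ever blocking producers.
buf := newSCStateUpdateBuffer()
for i := 0; i < 3; i++ {
	buf.put(&scStateUpdate{state: connectivity.Connecting}) // never blocks
}
for i := 0; i < 3; i++ {
	u := <-buf.get() // take the staged update...
	buf.load()       // ...then ask the buffer to stage the next one
	fmt.Println(u.state)
}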
+
+func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
+	// When updating addresses for a SubConn, if the address in use is not in
+	// the new addresses, the old ac will be tearDown() and a new ac will be
+	// created. tearDown() generates a state change with the Shutdown state; we
+	// don't want the balancer to receive this state change. So before
+	// tearDown() on the old ac, ac.acbw (the acBalancerWrapper) will be set to
+	// nil, and this function will be called with (nil, Shutdown). We don't
+	// need to call a balancer method in this case.
+	if sc == nil {
+		return
+	}
+	ccb.stateChangeQueue.put(&scStateUpdate{
+		sc:    sc,
+		state: s,
+	})
+}
+
+func (ccb *ccBalancerWrapper) handleResolvedAddrs(addrs []resolver.Address, err error) {
+	select {
+	case <-ccb.resolverUpdateCh:
+	default:
+	}
+	ccb.resolverUpdateCh <- &resolverUpdate{
+		addrs: addrs,
+		err:   err,
+	}
 }
 
 func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
@@ -64,8 +185,9 @@ func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) {
 }
 
 func (ccb *ccBalancerWrapper) UpdateBalancerState(s connectivity.State, p balancer.Picker) {
-	// TODO(bar) update cc connectivity state.
-	ccb.cc.updatePicker(p)
+	grpclog.Infof("ccBalancerWrapper: updating state and picker called by balancer: %v, %p", s, p)
+	ccb.cc.csMgr.updateState(s)
+	ccb.cc.blockingpicker.updatePicker(p)
 }
 
 func (ccb *ccBalancerWrapper) Target() string {
@@ -83,11 +205,14 @@ func (acbw *acBalancerWrapper) UpdateAddresses(addrs []resolver.Address) {
 	grpclog.Infof("acBalancerWrapper: UpdateAddresses called with %v", addrs)
 	acbw.mu.Lock()
 	defer acbw.mu.Unlock()
-	// TODO(bar) update the addresses or tearDown and create a new ac.
 	if !acbw.ac.tryUpdateAddrs(addrs) {
 		cc := acbw.ac.cc
 		acbw.ac.mu.Lock()
-		// Set old ac.acbw to nil so the states update will be ignored by balancer.
+		// Set old ac.acbw to nil so the Shutdown state update will be ignored
+		// by balancer.
+		//
+		// TODO(bar) the state transition could be wrong when tearDown() old ac
+		// and creating new ac, fix the transition.
 		acbw.ac.acbw = nil
 		acbw.ac.mu.Unlock()
 		acState := acbw.ac.getState()
diff --git a/balancer_test.go b/balancer_test.go
index 988859b6c..29dbe0a67 100644
--- a/balancer_test.go
+++ b/balancer_test.go
@@ -28,6 +28,7 @@ import (
 
 	"golang.org/x/net/context"
 	"google.golang.org/grpc/codes"
+	_ "google.golang.org/grpc/grpclog/glogger"
 	"google.golang.org/grpc/naming"
 	"google.golang.org/grpc/test/leakcheck"
 )
@@ -456,7 +457,7 @@ func TestPickFirstEmptyAddrs(t *testing.T) {
 	defer leakcheck.Check(t)
 	servers, r, cleanup := startServers(t, 1, math.MaxUint32)
 	defer cleanup()
-	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancer(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
+	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
 	if err != nil {
 		t.Fatalf("Failed to create ClientConn: %v", err)
 	}
@@ -488,7 +489,7 @@ func TestPickFirstCloseWithPendingRPC(t *testing.T) {
 	defer leakcheck.Check(t)
 	servers, r, cleanup := startServers(t, 1, math.MaxUint32)
 	defer cleanup()
-	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancer(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
+	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
 	if err != nil {
 		t.Fatalf("Failed to create ClientConn: %v", err)
 	}
@@ -542,7 +543,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
 	numServers := 3
 	servers, r, cleanup := startServers(t, numServers, math.MaxUint32)
 	defer cleanup()
-	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancer(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
+	cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
 	if err != nil {
 		t.Fatalf("Failed to create ClientConn: %v", err)
 	}
@@ -643,7 +644,7 @@ func TestPickFirstOrderAllServerUp(t *testing.T) {
 	}
 	for i := 0; i < 20; i++ {
 		if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port {
"/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { - t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port) + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port) } time.Sleep(10 * time.Millisecond) } @@ -655,7 +656,7 @@ func TestPickFirstOrderOneServerDown(t *testing.T) { numServers := 3 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancer(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("foo.bar.com", WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } @@ -746,7 +747,7 @@ func TestPickFirstOneAddressRemoval(t *testing.T) { numServers := 2 servers, r, cleanup := startServers(t, numServers, math.MaxUint32) defer cleanup() - cc, err := Dial("localhost:"+servers[0].port, WithBalancer(pickFirstBalancer(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + cc, err := Dial("localhost:"+servers[0].port, WithBalancer(pickFirstBalancerV1(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) if err != nil { t.Fatalf("Failed to create ClientConn: %v", err) } diff --git a/balancer_v1_wrapper.go b/balancer_v1_wrapper.go index b6002b739..7d854d42f 100644 --- a/balancer_v1_wrapper.go +++ b/balancer_v1_wrapper.go @@ -23,6 +23,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/resolver" @@ -45,6 +46,8 @@ func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.B startCh: make(chan struct{}), conns: make(map[resolver.Address]balancer.SubConn), connSt: make(map[balancer.SubConn]*scState), + csEvltr: &connectivityStateEvaluator{}, + state: connectivity.Idle, } cc.UpdateBalancerState(connectivity.Idle, bw) go bw.lbWatcher() @@ -67,6 +70,10 @@ type balancerWrapper struct { cc balancer.ClientConn + // To aggregate the connectivity state. + csEvltr *connectivityStateEvaluator + state connectivity.State + mu sync.Mutex conns map[resolver.Address]balancer.SubConn connSt map[balancer.SubConn]*scState @@ -134,7 +141,7 @@ func (bw *balancerWrapper) lbWatcher() { newAddr := resolver.Address{ Addr: a.Addr, Type: resolver.Backend, // All addresses from balancer are all backends. - ServerName: "", // TODO(bar) support servername. + ServerName: "", Metadata: a.Metadata, } newAddrs = append(newAddrs, newAddr) @@ -173,7 +180,7 @@ func (bw *balancerWrapper) lbWatcher() { resAddrs[resolver.Address{ Addr: a.Addr, Type: resolver.Backend, // All addresses from balancer are all backends. - ServerName: "", // TODO(bar) support servername. + ServerName: "", Metadata: a.Metadata, }] = a } @@ -187,7 +194,7 @@ func (bw *balancerWrapper) lbWatcher() { if _, ok := resAddrs[a]; !ok { del = append(del, c) delete(bw.conns, a) - delete(bw.connSt, c) + // Keep the state of this sc in bw.connSt until its state becomes Shutdown. } } bw.mu.Unlock() @@ -230,12 +237,18 @@ func (bw *balancerWrapper) HandleSubConnStateChange(sc balancer.SubConn, s conne scSt.down = bw.balancer.Up(scSt.addr) } else if oldS == connectivity.Ready && s != connectivity.Ready { if scSt.down != nil { - scSt.down(errConnClosing) // TODO(bar) what error to use? + scSt.down(errConnClosing) } } - // The connectivity state is ignored by clientConn now. 
-	// TODO(bar) use the aggregated connectivity state.
-	bw.cc.UpdateBalancerState(connectivity.Ready, bw)
+	sa := bw.csEvltr.recordTransition(oldS, s)
+	if bw.state != sa {
+		bw.state = sa
+	}
+	bw.cc.UpdateBalancerState(bw.state, bw)
+	if s == connectivity.Shutdown {
+		// Remove state for this sc.
+		delete(bw.connSt, sc)
+	}
 	return
 }
@@ -276,27 +289,79 @@ func (bw *balancerWrapper) Pick(ctx context.Context, opts balancer.PickOptions)
 	if err != nil {
 		return nil, nil, err
 	}
-	var put func(balancer.DoneInfo)
+	var done func(balancer.DoneInfo)
 	if p != nil {
-		put = func(i balancer.DoneInfo) { p() }
+		done = func(i balancer.DoneInfo) { p() }
 	}
 	var sc balancer.SubConn
+	bw.mu.Lock()
+	defer bw.mu.Unlock()
 	if bw.pickfirst {
-		bw.mu.Lock()
 		// Get the first sc in conns.
 		for _, sc = range bw.conns {
 			break
 		}
-		bw.mu.Unlock()
 	} else {
-		bw.mu.Lock()
-		sc = bw.conns[resolver.Address{
+		var ok bool
+		sc, ok = bw.conns[resolver.Address{
 			Addr:       a.Addr,
 			Type:       resolver.Backend,
-			ServerName: "", // TODO(bar) support servername.
+			ServerName: "",
 			Metadata:   a.Metadata,
 		}]
-		bw.mu.Unlock()
+		if !ok && failfast {
+			return nil, nil, Errorf(codes.Unavailable, "there is no connection available")
+		}
+		if s, ok := bw.connSt[sc]; failfast && (!ok || s.s != connectivity.Ready) {
+			// If the returned sc is not ready and RPC is failfast,
+			// return error, and this RPC will fail.
+			return nil, nil, Errorf(codes.Unavailable, "there is no connection available")
+		}
 	}
-	return sc, put, nil
+
+	return sc, done, nil
+}
+
+// connectivityStateEvaluator gets updated by addrConns when their
+// states transition, based on which it evaluates the state of
+// ClientConn.
+type connectivityStateEvaluator struct {
+	mu                  sync.Mutex
+	numReady            uint64 // Number of addrConns in ready state.
+	numConnecting       uint64 // Number of addrConns in connecting state.
+	numTransientFailure uint64 // Number of addrConns in transientFailure.
+}
+
+// recordTransition records a state change happening in every subConn, and based on
+// that it evaluates what the aggregated state should be.
+// It can only transition between Ready, Connecting and TransientFailure. The other states,
+// Idle and Shutdown, are transitioned into by ClientConn; at the beginning of the connection,
+// before any subConn is created, ClientConn is in the Idle state. In the end, when ClientConn
+// closes, it is in the Shutdown state.
+// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state.
+func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) connectivity.State {
+	cse.mu.Lock()
+	defer cse.mu.Unlock()
+
+	// Update counters.
+	for idx, state := range []connectivity.State{oldState, newState} {
+		updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
+		switch state {
+		case connectivity.Ready:
+			cse.numReady += updateVal
+		case connectivity.Connecting:
+			cse.numConnecting += updateVal
+		case connectivity.TransientFailure:
+			cse.numTransientFailure += updateVal
+		}
+	}
+
+	// Evaluate.
+	if cse.numReady > 0 {
+		return connectivity.Ready
+	}
+	if cse.numConnecting > 0 {
+		return connectivity.Connecting
+	}
+	return connectivity.TransientFailure
+}
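Both pickers in this change return an optional done callback, and invoke() in call.go (next file) calls it exactly once per attempt with transport-level stats. A sketch of a custom picker exploiting that hook for in-flight accounting (hypothetical type, not part of this PR; assumes "sync/atomic", "golang.org/x/net/context" and the balancer package):

// Sketch: a picker that tracks in-flight RPCs via the done callback.
type countingPicker struct {
	inner    balancer.Picker
	inFlight int64
}

func (p *countingPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
	sc, done, err := p.inner.Pick(ctx, opts)
	if err != nil {
		return nil, nil, err
	}
	atomic.AddInt64(&p.inFlight, 1)
	return sc, func(di balancer.DoneInfo) {
		atomic.AddInt64(&p.inFlight, -1) // the RPC attempt finished
		if done != nil {
			done(di) // preserve the wrapped picker's bookkeeping
		}
	}, nil
}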
diff --git a/call.go b/call.go
index 1c7d1c135..1ef2507c3 100644
--- a/call.go
+++ b/call.go
@@ -207,9 +207,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		err    error
 		t      transport.ClientTransport
 		stream *transport.Stream
-		// Record the put handler from Balancer.Get(...). It is called once the
+		// Record the done handler returned by getTransport. It is called once the
 		// RPC has completed or failed.
-		put func(balancer.DoneInfo)
+		done func(balancer.DoneInfo)
 	)
 	// TODO(zhaoq): Need a formal spec of fail-fast.
 	callHdr := &transport.CallHdr{
@@ -223,10 +223,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		callHdr.Creds = c.creds
 	}
 
-	gopts := BalancerGetOptions{
-		BlockingWait: !c.failFast,
-	}
-	t, put, err = cc.getTransport(ctx, gopts)
+	t, done, err = cc.getTransport(ctx, c.failFast)
 	if err != nil {
 		// TODO(zhaoq): Probably revisit the error handling.
 		if _, ok := status.FromError(err); ok {
@@ -246,14 +243,14 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 	}
 	stream, err = t.NewStream(ctx, callHdr)
 	if err != nil {
-		if put != nil {
+		if done != nil {
 			if _, ok := err.(transport.ConnectionError); ok {
 				// If error is connection error, transport was sending data on wire,
 				// and we are not sure if anything has been sent on wire.
 				// If error is not connection error, we are sure nothing has been sent.
 				updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
 			}
-			put(balancer.DoneInfo{Err: err})
+			done(balancer.DoneInfo{Err: err})
 		}
 		if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
 			continue
@@ -265,12 +262,12 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 	}
 	err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts)
 	if err != nil {
-		if put != nil {
+		if done != nil {
 			updateRPCInfoInContext(ctx, rpcInfo{
 				bytesSent:     stream.BytesSent(),
 				bytesReceived: stream.BytesReceived(),
 			})
-			put(balancer.DoneInfo{Err: err})
+			done(balancer.DoneInfo{Err: err})
 		}
 		// Retry a non-failfast RPC when
 		// i) there is a connection error; or
@@ -282,12 +279,12 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 	}
 	err = recvResponse(ctx, cc.dopts, t, c, stream, reply)
 	if err != nil {
-		if put != nil {
+		if done != nil {
 			updateRPCInfoInContext(ctx, rpcInfo{
 				bytesSent:     stream.BytesSent(),
 				bytesReceived: stream.BytesReceived(),
 			})
-			put(balancer.DoneInfo{Err: err})
+			done(balancer.DoneInfo{Err: err})
 		}
 		if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
 			continue
@@ -298,12 +295,12 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
 	}
 	t.CloseStream(stream, nil)
-	if put != nil {
+	if done != nil {
 		updateRPCInfoInContext(ctx, rpcInfo{
 			bytesSent:     stream.BytesSent(),
 			bytesReceived: stream.BytesReceived(),
 		})
-		put(balancer.DoneInfo{Err: err})
+		done(balancer.DoneInfo{Err: err})
 	}
 	return stream.Status().Err()
 }
diff --git a/call_test.go b/call_test.go
index 7dbe52d04..f48d30e87 100644
--- a/call_test.go
+++ b/call_test.go
@@ -117,6 +117,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
 type server struct {
 	lis        net.Listener
 	port       string
+	addr       string
 	startedErr chan error // sent nil or an error after server starts
 	mu         sync.Mutex
 	conns      map[transport.ServerTransport]bool
@@ -138,7 +139,8 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) {
 		s.startedErr <- fmt.Errorf("failed to listen: %v", err)
 		return
 	}
-	_, p, err := net.SplitHostPort(s.lis.Addr().String())
+	s.addr = s.lis.Addr().String()
+	_, p, err := net.SplitHostPort(s.addr)
 	if err != nil {
 		s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
 		return
 	}
diff --git a/clientconn.go b/clientconn.go
index d0d479b48..7a61f9c53 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -20,6 +20,7 @@ package grpc
 
 import (
 	"errors"
+	"fmt"
 	"math"
 	"net"
 	"reflect"
@@ -178,6 +179,15 @@ func WithBalancer(b Balancer) DialOption {
 	}
 }
 
+// WithBalancerBuilder is for testing only. Users who need custom balancers should
+// register their balancer and use service config to choose the balancer to use.
+func WithBalancerBuilder(b balancer.Builder) DialOption {
+	// TODO(bar) remove this when switching balancer is done.
+	return func(o *dialOptions) {
+		o.balancerBuilder = b
+	}
+}
+
 // WithServiceConfig returns a DialOption which has a channel to read the service configuration.
 func WithServiceConfig(c <-chan ServiceConfig) DialOption {
 	return func(o *dialOptions) {
@@ -339,13 +349,30 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
 		target: target,
 		csMgr:  &connectivityStateManager{},
 		conns:  make(map[*addrConn]struct{}),
+
+		blockingpicker: newPickerWrapper(),
 	}
-	cc.csEvltr = &connectivityStateEvaluator{csMgr: cc.csMgr}
 	cc.ctx, cc.cancel = context.WithCancel(context.Background())
 
 	for _, opt := range opts {
 		opt(&cc.dopts)
 	}
+
+	if !cc.dopts.insecure {
+		if cc.dopts.copts.TransportCredentials == nil {
+			return nil, errNoTransportSecurity
+		}
+	} else {
+		if cc.dopts.copts.TransportCredentials != nil {
+			return nil, errCredentialsConflict
+		}
+		for _, cd := range cc.dopts.copts.PerRPCCredentials {
+			if cd.RequireTransportSecurity() {
+				return nil, errTransportCredentialsMissing
+			}
+		}
+	}
+
 	cc.mkp = cc.dopts.copts.KeepaliveParams
 
 	if cc.dopts.copts.Dialer == nil {
@@ -408,7 +435,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
 		cc.authority = target
 	}
 
-	// TODO(bar) parse scheme and start resolver.
 	if cc.dopts.balancerBuilder != nil {
 		var credsClone credentials.TransportCredentials
 		if creds != nil {
@@ -420,7 +446,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
 		}
 		// Build should not take long time. So it's ok to not have a goroutine for it.
 		// TODO(bar) init balancer after first resolver result to support service config balancer.
-		cc.balancer = cc.dopts.balancerBuilder.Build(&ccBalancerWrapper{cc: cc}, buildOpts)
+		cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, buildOpts)
 	} else {
 		waitC := make(chan error, 1)
 		go func() {
@@ -460,11 +486,18 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
 		go cc.scWatcher()
 	}
 
-	if cc.balancer != nil {
-		// Unblock balancer initialization with a fake resolver update.
+	// Build the resolver.
+	cc.resolverWrapper, err = newCCResolverWrapper(cc)
+	if err != nil {
+		return nil, fmt.Errorf("failed to build resolver: %v", err)
+	}
+
+	if cc.balancerWrapper != nil && cc.resolverWrapper == nil {
+		// TODO(bar) there should always be a resolver (DNS as the default).
+		// Unblock balancer initialization with a fake resolver update if there's no resolver.
 		// The balancer wrapper will not read the addresses, so an empty list works.
 		// TODO(bar) remove this after the real resolver is started.
-		cc.balancer.HandleResolvedAddrs([]resolver.Address{}, nil)
+		cc.balancerWrapper.handleResolvedAddrs([]resolver.Address{}, nil)
 	}
 
 	// A blocking dial blocks until the clientConn is ready.
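The credential checks now run inside DialContext itself (they previously lived in addrConn.connect, removed later in this diff), so a misconfigured client fails at dial time. A sketch of the two option combinations the new checks reject, with creds standing in for any credentials.TransportCredentials value:

// Rejected with errCredentialsConflict: WithInsecure plus transport credentials.
_, err := grpc.Dial("localhost:50051",
	grpc.WithInsecure(),
	grpc.WithTransportCredentials(creds), // creds is hypothetical here
)

// Rejected with errNoTransportSecurity: neither WithInsecure nor credentials.
_, err = grpc.Dial("localhost:50051")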
@@ -484,54 +517,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
 	return cc, nil
 }
 
-// connectivityStateEvaluator gets updated by addrConns when their
-// states transition, based on which it evaluates the state of
-// ClientConn.
-// Note: This code will eventually sit in the balancer in the new design.
-type connectivityStateEvaluator struct {
-	csMgr               *connectivityStateManager
-	mu                  sync.Mutex
-	numReady            uint64 // Number of addrConns in ready state.
-	numConnecting       uint64 // Number of addrConns in connecting state.
-	numTransientFailure uint64 // Number of addrConns in transientFailure.
-}
-
-// recordTransition records state change happening in every addrConn and based on
-// that it evaluates what state the ClientConn is in.
-// It can only transition between connectivity.Ready, connectivity.Connecting and connectivity.TransientFailure. Other states,
-// Idle and connectivity.Shutdown are transitioned into by ClientConn; in the beginning of the connection
-// before any addrConn is created ClientConn is in idle state. In the end when ClientConn
-// closes it is in connectivity.Shutdown state.
-// TODO Note that in later releases, a ClientConn with no activity will be put into an Idle state.
-func (cse *connectivityStateEvaluator) recordTransition(oldState, newState connectivity.State) {
-	cse.mu.Lock()
-	defer cse.mu.Unlock()
-
-	// Update counters.
-	for idx, state := range []connectivity.State{oldState, newState} {
-		updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
-		switch state {
-		case connectivity.Ready:
-			cse.numReady += updateVal
-		case connectivity.Connecting:
-			cse.numConnecting += updateVal
-		case connectivity.TransientFailure:
-			cse.numTransientFailure += updateVal
-		}
-	}
-
-	// Evaluate.
-	if cse.numReady > 0 {
-		cse.csMgr.updateState(connectivity.Ready)
-		return
-	}
-	if cse.numConnecting > 0 {
-		cse.csMgr.updateState(connectivity.Connecting)
-		return
-	}
-	cse.csMgr.updateState(connectivity.TransientFailure)
-}
-
 // connectivityStateManager keeps the connectivity.State of ClientConn.
 // This struct will eventually be exported so the balancers can access it.
 type connectivityStateManager struct {
@@ -584,13 +569,11 @@ type ClientConn struct {
 	authority string
 	dopts     dialOptions
 	csMgr     *connectivityStateManager
-	csEvltr   *connectivityStateEvaluator // This will eventually be part of balancer.
 
-	balancer balancer.Balancer
+	balancerWrapper *ccBalancerWrapper
+	resolverWrapper *ccResolverWrapper
 
-	// TODO(bar) move the mutex and picker into a struct that does blocking pick().
-	pmu    sync.Mutex
-	picker balancer.Picker
+	blockingpicker *pickerWrapper
 
 	mu sync.RWMutex
 	sc ServiceConfig
@@ -647,7 +630,6 @@ func (cc *ClientConn) newAddrConn(addrs []resolver.Address) (*addrConn, error) {
 		dopts: cc.dopts,
 	}
 	ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
-	ac.csEvltr = cc.csEvltr
 	// Track ac in cc. This needs to be done before any getTransport(...) is called.
 	cc.mu.Lock()
 	if cc.conns == nil {
@@ -674,6 +656,7 @@ func (cc *ClientConn) removeAddrConn(ac *addrConn, err error) {
 
 // connect starts to creating transport and also starts the transport monitor
 // goroutine for this ac.
+// It does nothing if the ac is not IDLE.
 // TODO(bar) Move this to the addrConn section.
 // This was part of resetAddrConn, keep it here to make the diff look clean.
 func (ac *addrConn) connect(block bool) error {
@@ -682,25 +665,17 @@
 	ac.mu.Lock()
 	if ac.state == connectivity.Shutdown {
 		ac.mu.Unlock()
 		return errConnClosing
 	}
-	ac.mu.Unlock()
-
-	if EnableTracing {
-		ac.events = trace.NewEventLog("grpc.ClientConn", ac.addrs[0].Addr)
+	if ac.state != connectivity.Idle {
+		ac.mu.Unlock()
+		return nil
 	}
-	if !ac.dopts.insecure {
-		if ac.dopts.copts.TransportCredentials == nil {
-			return errNoTransportSecurity
-		}
+	ac.state = connectivity.Connecting
+	if ac.cc.balancerWrapper != nil {
+		ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
 	} else {
-		if ac.dopts.copts.TransportCredentials != nil {
-			return errCredentialsConflict
-		}
-		for _, cd := range ac.dopts.copts.PerRPCCredentials {
-			if cd.RequireTransportSecurity() {
-				return errTransportCredentialsMissing
-			}
-		}
+		ac.cc.csMgr.updateState(ac.state)
 	}
+	ac.mu.Unlock()
 
 	if block {
 		if err := ac.resetTransport(false); err != nil {
@@ -780,56 +755,37 @@ func (cc *ClientConn) GetMethodConfig(method string) MethodConfig {
 	return m
 }
 
-func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(balancer.DoneInfo), error) {
-	var (
-		ac  *addrConn
-		put func(balancer.DoneInfo)
-	)
-	if cc.balancer == nil {
+func (cc *ClientConn) getTransport(ctx context.Context, failfast bool) (transport.ClientTransport, func(balancer.DoneInfo), error) {
+	if cc.balancerWrapper == nil {
 		// If balancer is nil, there should be only one addrConn available.
 		cc.mu.RLock()
 		if cc.conns == nil {
 			cc.mu.RUnlock()
+			// TODO: this function returns a mix of toRPCErr-wrapped and plain
+			// errors. Clean up the errors in ClientConn.
 			return nil, nil, toRPCErr(ErrClientConnClosing)
 		}
+		var ac *addrConn
 		for ac = range cc.conns {
 			// Break after the first iteration to get the first addrConn.
 			break
 		}
 		cc.mu.RUnlock()
-	} else {
-		cc.pmu.Lock()
-		// TODO(bar) call pick on struct blockPicker instead of the real picker.
-		p := cc.picker
-		cc.pmu.Unlock()
-
-		var (
-			err error
-			sc  balancer.SubConn
-		)
-		sc, put, err = p.Pick(ctx, balancer.PickOptions{})
+		if ac == nil {
+			return nil, nil, errConnClosing
+		}
+		t, err := ac.wait(ctx, false /*hasBalancer*/, failfast)
 		if err != nil {
-			return nil, nil, toRPCErr(err)
-		}
-		if acbw, ok := sc.(*acBalancerWrapper); ok {
-			ac = acbw.getAddrConn()
-		} else if put != nil {
-			updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
-			put(balancer.DoneInfo{Err: errors.New("SubConn returned by pick cannot be recognized")})
+			return nil, nil, err
 		}
+		return t, nil, nil
 	}
-	if ac == nil {
-		return nil, nil, errConnClosing
-	}
-	t, err := ac.wait(ctx, cc.balancer != nil, !opts.BlockingWait)
+
+	t, done, err := cc.blockingpicker.pick(ctx, failfast, balancer.PickOptions{})
 	if err != nil {
-		if put != nil {
-			updateRPCInfoInContext(ctx, rpcInfo{bytesSent: false, bytesReceived: false})
-			put(balancer.DoneInfo{Err: err})
-		}
-		return nil, nil, err
+		return nil, nil, toRPCErr(err)
 	}
-	return t, put, nil
+	return t, done, nil
 }
 
 // Close tears down the ClientConn and all underlying connections.
@@ -845,8 +801,12 @@
 	cc.conns = nil
 	cc.csMgr.updateState(connectivity.Shutdown)
 	cc.mu.Unlock()
-	if cc.balancer != nil {
-		cc.balancer.Close()
+	cc.blockingpicker.close()
+	if cc.resolverWrapper != nil {
+		cc.resolverWrapper.close()
+	}
+	if cc.balancerWrapper != nil {
+		cc.balancerWrapper.close()
 	}
 	for ac := range conns {
 		ac.tearDown(ErrClientConnClosing)
 	}
@@ -866,8 +826,6 @@ type addrConn struct {
 	events trace.EventLog
 	acbw   balancer.SubConn
 
-	csEvltr *connectivityStateEvaluator
-
 	mu    sync.Mutex
 	state connectivity.State
 	// ready is closed and becomes nil when a new transport is up or failed
@@ -920,19 +878,19 @@ func (ac *addrConn) resetTransport(drain bool) error {
 		ac.mu.Unlock()
 		return errConnClosing
 	}
-	oldState := ac.state
-	ac.state = connectivity.Connecting
-	ac.csEvltr.recordTransition(oldState, ac.state)
-	if ac.cc.balancer != nil {
-		ac.cc.balancer.HandleSubConnStateChange(ac.acbw, ac.state)
+	ac.state = connectivity.TransientFailure
+	if ac.cc.balancerWrapper != nil {
+		ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
+	} else {
+		ac.cc.csMgr.updateState(ac.state)
 	}
-	// TODO(bar) don't call balancer functions to handle subconn state change if ac.acbw is nil.
 	if ac.ready != nil {
 		close(ac.ready)
 		ac.ready = nil
 	}
 	t := ac.transport
 	ac.transport = nil
+	ac.curAddr = resolver.Address{}
 	ac.mu.Unlock()
 	if t != nil && !drain {
 		t.Close()
@@ -953,16 +911,17 @@
 		return errConnClosing
 	}
 	ac.printf("connecting")
-	oldState := ac.state
 	ac.state = connectivity.Connecting
-	ac.csEvltr.recordTransition(oldState, ac.state)
 	// TODO(bar) remove condition once we always have a balancer.
-	if ac.cc.balancer != nil {
-		ac.cc.balancer.HandleSubConnStateChange(ac.acbw, ac.state)
+	if ac.cc.balancerWrapper != nil {
+		ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
+	} else {
+		ac.cc.csMgr.updateState(ac.state)
 	}
 	// copy ac.addrs in case of race
 	addrsIter := make([]resolver.Address, len(ac.addrs))
 	copy(addrsIter, ac.addrs)
+	copts := ac.dopts.copts
 	ac.mu.Unlock()
 	for _, addr := range addrsIter {
 		ac.mu.Lock()
@@ -977,7 +936,7 @@
 			Addr:     addr.Addr,
 			Metadata: addr.Metadata,
 		}
-		newTransport, err := transport.NewClientTransport(ctx, sinfo, ac.dopts.copts)
+		newTransport, err := transport.NewClientTransport(ctx, sinfo, copts)
 		// Don't call cancel in success path due to a race in Go 1.6:
 		// https://github.com/golang/go/issues/15078.
 		if err != nil {
@@ -1004,13 +963,17 @@
 			newTransport.Close()
 			return errConnClosing
 		}
-		oldState = ac.state
 		ac.state = connectivity.Ready
-		ac.csEvltr.recordTransition(oldState, ac.state)
-		if ac.cc.balancer != nil {
-			ac.cc.balancer.HandleSubConnStateChange(ac.acbw, ac.state)
+		if ac.cc.balancerWrapper != nil {
+			ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
+		} else {
+			ac.cc.csMgr.updateState(ac.state)
 		}
+		t := ac.transport
 		ac.transport = newTransport
+		if t != nil {
+			t.Close()
+		}
 		ac.curAddr = addr
 		if ac.ready != nil {
 			close(ac.ready)
@@ -1020,11 +983,11 @@
 		return nil
 	}
 	ac.mu.Lock()
-	oldState = ac.state
 	ac.state = connectivity.TransientFailure
-	ac.csEvltr.recordTransition(oldState, ac.state)
-	if ac.cc.balancer != nil {
-		ac.cc.balancer.HandleSubConnStateChange(ac.acbw, ac.state)
+	if ac.cc.balancerWrapper != nil {
+		ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
+	} else {
+		ac.cc.csMgr.updateState(ac.state)
 	}
 	if ac.ready != nil {
 		close(ac.ready)
@@ -1145,6 +1108,28 @@
 	}
 }
 
+// getReadyTransport returns the transport if ac's state is READY.
+// Otherwise it returns nil, false.
+// If ac's state is IDLE, it will trigger ac to connect.
+func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) {
+	ac.mu.Lock()
+	if ac.state == connectivity.Ready {
+		t := ac.transport
+		ac.mu.Unlock()
+		return t, true
+	}
+	var idle bool
+	if ac.state == connectivity.Idle {
+		idle = true
+	}
+	ac.mu.Unlock()
+	// Trigger idle ac to connect.
+	if idle {
+		ac.connect(false)
+	}
+	return nil, false
+}
+
 // tearDown starts to tear down the addrConn.
 // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
 // some edge cases (e.g., the caller opens and closes many addrConn's in a
@@ -1166,12 +1151,12 @@ func (ac *addrConn) tearDown(err error) {
 	if ac.state == connectivity.Shutdown {
 		return
 	}
-	oldState := ac.state
 	ac.state = connectivity.Shutdown
 	ac.tearDownErr = err
-	ac.csEvltr.recordTransition(oldState, ac.state)
-	if ac.cc.balancer != nil {
-		ac.cc.balancer.HandleSubConnStateChange(ac.acbw, ac.state)
+	if ac.cc.balancerWrapper != nil {
+		ac.cc.balancerWrapper.handleSubConnStateChange(ac.acbw, ac.state)
+	} else {
+		ac.cc.csMgr.updateState(ac.state)
 	}
 	if ac.events != nil {
 		ac.events.Finish()
diff --git a/grpclb.go b/grpclb.go
index ebda2a1e5..db56ff362 100644
--- a/grpclb.go
+++ b/grpclb.go
@@ -461,6 +461,7 @@ func (b *grpclbBalancer) Start(target string, config BalancerConfig) error {
 		// WithDialer takes a different type of function, so we instead use a special DialOption here.
 		dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer })
 	}
+	dopts = append(dopts, WithBlock())
 	ccError = make(chan struct{})
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
 	cc, err = DialContext(ctx, rb.addr, dopts...)
diff --git a/picker_wrapper.go b/picker_wrapper.go
new file mode 100644
index 000000000..9085dbc9c
--- /dev/null
+++ b/picker_wrapper.go
@@ -0,0 +1,141 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
diff --git a/picker_wrapper.go b/picker_wrapper.go
new file mode 100644
index 000000000..9085dbc9c
--- /dev/null
+++ b/picker_wrapper.go
@@ -0,0 +1,141 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc
+
+import (
+	"sync"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/grpclog"
+	"google.golang.org/grpc/status"
+	"google.golang.org/grpc/transport"
+)
+
+// pickerWrapper is a wrapper of balancer.Picker. It blocks on certain pick
+// actions and unblocks when there's a picker update.
+type pickerWrapper struct {
+	mu         sync.Mutex
+	done       bool
+	blockingCh chan struct{}
+	picker     balancer.Picker
+}
+
+func newPickerWrapper() *pickerWrapper {
+	bp := &pickerWrapper{blockingCh: make(chan struct{})}
+	return bp
+}
+
+// updatePicker is called by UpdateBalancerState. It unblocks all blocked picks.
+func (bp *pickerWrapper) updatePicker(p balancer.Picker) {
+	bp.mu.Lock()
+	if bp.done {
+		bp.mu.Unlock()
+		return
+	}
+	bp.picker = p
+	// bp.blockingCh should never be nil.
+	close(bp.blockingCh)
+	bp.blockingCh = make(chan struct{})
+	bp.mu.Unlock()
+}
+
+// pick returns the transport that will be used for the RPC.
+// It may block in the following cases:
+// - there's no picker
+// - the current picker returns ErrNoSubConnAvailable
+// - the current picker returns other errors and failfast is false.
+// - the subConn returned by the current picker is not READY
+// When one of these situations happens, pick blocks until the picker gets updated.
+func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.PickOptions) (transport.ClientTransport, func(balancer.DoneInfo), error) {
+	var (
+		p  balancer.Picker
+		ch chan struct{}
+	)
+
+	for {
+		bp.mu.Lock()
+		if bp.done {
+			bp.mu.Unlock()
+			return nil, nil, ErrClientConnClosing
+		}
+
+		if bp.picker == nil {
+			ch = bp.blockingCh
+		}
+		if ch == bp.blockingCh {
+			// This could happen when either:
+			// - bp.picker is nil (the previous if condition), or
+			// - pick was already called on the current picker.
+			bp.mu.Unlock()
+			select {
+			case <-ctx.Done():
+				return nil, nil, ctx.Err()
+			case <-ch:
+			}
+			continue
+		}
+
+		ch = bp.blockingCh
+		p = bp.picker
+		bp.mu.Unlock()
+
+		subConn, put, err := p.Pick(ctx, opts)
+
+		if err != nil {
+			switch err {
+			case balancer.ErrNoSubConnAvailable:
+				continue
+			case balancer.ErrTransientFailure:
+				if !failfast {
+					continue
+				}
+				return nil, nil, status.Errorf(codes.Unavailable, "%v", err)
+			default:
+				// err is some other error.
+				return nil, nil, toRPCErr(err)
+			}
+		}
+
+		acw, ok := subConn.(*acBalancerWrapper)
+		if !ok {
+			grpclog.Infof("subconn returned from pick is not *acBalancerWrapper")
+			continue
+		}
+		if t, ok := acw.getAddrConn().getReadyTransport(); ok {
+			return t, put, nil
+		}
+		grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick")
+		// If ok == false, ac.state is not READY.
+		// A valid picker always returns READY subConn. This means the state of ac
+		// just changed, and picker will be updated shortly.
+		// continue back to the beginning of the for loop to repick.
+	}
+}
+
+func (bp *pickerWrapper) close() {
+	bp.mu.Lock()
+	defer bp.mu.Unlock()
+	if bp.done {
+		return
+	}
+	bp.done = true
+	close(bp.blockingCh)
+}
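The close-and-replace channel idiom in updatePicker/pick above is worth isolating: a writer publishes a new value and closes the current channel to wake every waiter; readers that found nothing usable block on whichever channel was current when they looked. A minimal, self-contained version of the same pattern (names are mine, not gRPC's):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// latest lets readers block until a value has been published.
type latest struct {
	mu sync.Mutex
	ch chan struct{}
	v  interface{}
}

func newLatest() *latest { return &latest{ch: make(chan struct{})} }

func (l *latest) set(v interface{}) {
	l.mu.Lock()
	l.v = v
	close(l.ch)                // wake everyone blocked on the old channel
	l.ch = make(chan struct{}) // future waiters block on a fresh channel
	l.mu.Unlock()
}

func (l *latest) get() interface{} {
	for {
		l.mu.Lock()
		v, ch := l.v, l.ch
		l.mu.Unlock()
		if v != nil {
			return v
		}
		<-ch // block until the next set closes this channel
	}
}

func main() {
	l := newLatest()
	go func() { time.Sleep(10 * time.Millisecond); l.set("ready") }()
	fmt.Println(l.get()) // blocks briefly, then prints "ready"
}
```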
diff --git a/picker_wrapper_test.go b/picker_wrapper_test.go
new file mode 100644
index 000000000..23bc8f243
--- /dev/null
+++ b/picker_wrapper_test.go
@@ -0,0 +1,160 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc
+
+import (
+	"fmt"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/connectivity"
+	_ "google.golang.org/grpc/grpclog/glogger"
+	"google.golang.org/grpc/test/leakcheck"
+	"google.golang.org/grpc/transport"
+)
+
+const goroutineCount = 5
+
+var (
+	testT  = &testTransport{}
+	testSC = &acBalancerWrapper{ac: &addrConn{
+		state:     connectivity.Ready,
+		transport: testT,
+	}}
+	testSCNotReady = &acBalancerWrapper{ac: &addrConn{
+		state: connectivity.TransientFailure,
+	}}
+)
+
+type testTransport struct {
+	transport.ClientTransport
+}
+
+type testingPicker struct {
+	err       error
+	sc        balancer.SubConn
+	maxCalled int64
+}
+
+func (p *testingPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
+	if atomic.AddInt64(&p.maxCalled, -1) < 0 {
+		return nil, nil, fmt.Errorf("Pick called too many times (> goroutineCount)")
+	}
+	if p.err != nil {
+		return nil, nil, p.err
+	}
+	return p.sc, nil, nil
+}
+
+func TestBlockingPickTimeout(t *testing.T) {
+	defer leakcheck.Check(t)
+	bp := newPickerWrapper()
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	if _, _, err := bp.pick(ctx, true, balancer.PickOptions{}); err != context.DeadlineExceeded {
+		t.Errorf("bp.pick returned error %v, want DeadlineExceeded", err)
+	}
+}
+
+func TestBlockingPick(t *testing.T) {
+	defer leakcheck.Check(t)
+	bp := newPickerWrapper()
+	// All goroutines should block because picker is nil in bp.
+	var finishedCount uint64
+	for i := goroutineCount; i > 0; i-- {
+		go func() {
+			if tr, _, err := bp.pick(context.Background(), true, balancer.PickOptions{}); err != nil || tr != testT {
+				t.Errorf("bp.pick returned non-nil error: %v", err)
+			}
+			atomic.AddUint64(&finishedCount, 1)
+		}()
+	}
+	time.Sleep(50 * time.Millisecond)
+	if c := atomic.LoadUint64(&finishedCount); c != 0 {
+		t.Errorf("finished goroutines count: %v, want 0", c)
+	}
+	bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount})
+}
+
+func TestBlockingPickNoSubAvailable(t *testing.T) {
+	defer leakcheck.Check(t)
+	bp := newPickerWrapper()
+	var finishedCount uint64
+	bp.updatePicker(&testingPicker{err: balancer.ErrNoSubConnAvailable, maxCalled: goroutineCount})
+	// All goroutines should block because picker returns no sc available.
+	for i := goroutineCount; i > 0; i-- {
+		go func() {
+			if tr, _, err := bp.pick(context.Background(), true, balancer.PickOptions{}); err != nil || tr != testT {
+				t.Errorf("bp.pick returned non-nil error: %v", err)
+			}
+			atomic.AddUint64(&finishedCount, 1)
+		}()
+	}
+	time.Sleep(50 * time.Millisecond)
+	if c := atomic.LoadUint64(&finishedCount); c != 0 {
+		t.Errorf("finished goroutines count: %v, want 0", c)
+	}
+	bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount})
+}
+
+func TestBlockingPickTransientWaitforready(t *testing.T) {
+	defer leakcheck.Check(t)
+	bp := newPickerWrapper()
+	bp.updatePicker(&testingPicker{err: balancer.ErrTransientFailure, maxCalled: goroutineCount})
+	var finishedCount uint64
+	// All goroutines should block because picker returns transientFailure and
+	// picks are not failfast.
+	for i := goroutineCount; i > 0; i-- {
+		go func() {
+			if tr, _, err := bp.pick(context.Background(), false, balancer.PickOptions{}); err != nil || tr != testT {
+				t.Errorf("bp.pick returned non-nil error: %v", err)
+			}
+			atomic.AddUint64(&finishedCount, 1)
+		}()
+	}
+	time.Sleep(time.Millisecond)
+	if c := atomic.LoadUint64(&finishedCount); c != 0 {
+		t.Errorf("finished goroutines count: %v, want 0", c)
+	}
+	bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount})
+}
+
+func TestBlockingPickSCNotReady(t *testing.T) {
+	defer leakcheck.Check(t)
+	bp := newPickerWrapper()
+	bp.updatePicker(&testingPicker{sc: testSCNotReady, maxCalled: goroutineCount})
+	var finishedCount uint64
+	// All goroutines should block because sc is not ready.
+	for i := goroutineCount; i > 0; i-- {
+		go func() {
+			if tr, _, err := bp.pick(context.Background(), true, balancer.PickOptions{}); err != nil || tr != testT {
+				t.Errorf("bp.pick returned non-nil error: %v", err)
+			}
+			atomic.AddUint64(&finishedCount, 1)
+		}()
+	}
+	time.Sleep(time.Millisecond)
+	if c := atomic.LoadUint64(&finishedCount); c != 0 {
+		t.Errorf("finished goroutines count: %v, want 0", c)
+	}
+	bp.updatePicker(&testingPicker{sc: testSC, maxCalled: goroutineCount})
+}
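A linear sketch of the unblock path these tests exercise, assuming it runs inside package grpc next to the doubles defined above (testingPicker, testSC, testT):

```go
bp := newPickerWrapper()
go func() {
	time.Sleep(10 * time.Millisecond)
	// Installing a picker closes the current blockingCh, waking the pick below.
	bp.updatePicker(&testingPicker{sc: testSC, maxCalled: 1})
}()
// Blocks while bp.picker is nil, then returns testSC's ready transport.
tr, _, err := bp.pick(context.Background(), true, balancer.PickOptions{})
fmt.Println(tr == testT, err) // true <nil>
```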
diff --git a/pickfirst.go b/pickfirst.go
new file mode 100644
index 000000000..7f993ef5a
--- /dev/null
+++ b/pickfirst.go
@@ -0,0 +1,95 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc
+
+import (
+	"golang.org/x/net/context"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/connectivity"
+	"google.golang.org/grpc/grpclog"
+	"google.golang.org/grpc/resolver"
+)
+
+func newPickfirstBuilder() balancer.Builder {
+	return &pickfirstBuilder{}
+}
+
+type pickfirstBuilder struct{}
+
+func (*pickfirstBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
+	return &pickfirstBalancer{cc: cc}
+}
+
+func (*pickfirstBuilder) Name() string {
+	return "pickfirst"
+}
+
+type pickfirstBalancer struct {
+	cc balancer.ClientConn
+	sc balancer.SubConn
+}
+
+func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
+	if err != nil {
+		grpclog.Infof("pickfirstBalancer: HandleResolvedAddrs called with error %v", err)
+		return
+	}
+	if b.sc == nil {
+		b.sc, err = b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{})
+		if err != nil {
+			grpclog.Errorf("pickfirstBalancer: failed to NewSubConn: %v", err)
+			return
+		}
+		b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc})
+	} else {
+		b.sc.UpdateAddresses(addrs)
+	}
+}
+
+func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
+	grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s)
+	if b.sc != sc || s == connectivity.Shutdown {
+		b.sc = nil
+		return
+	}
+
+	switch s {
+	case connectivity.Ready, connectivity.Idle:
+		b.cc.UpdateBalancerState(s, &picker{sc: sc})
+	case connectivity.Connecting:
+		b.cc.UpdateBalancerState(s, &picker{err: balancer.ErrNoSubConnAvailable})
+	case connectivity.TransientFailure:
+		b.cc.UpdateBalancerState(s, &picker{err: balancer.ErrTransientFailure})
+	}
+}
+
+func (b *pickfirstBalancer) Close() {
+}
+
+type picker struct {
+	err error
+	sc  balancer.SubConn
+}
+
+func (p *picker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
+	if p.err != nil {
+		return nil, nil, p.err
+	}
+	return p.sc, nil, nil
+}
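A hedged wiring sketch for the new pickfirst balancer, in the style the tests below use. The builder is unexported, so this only works inside package grpc; addresses are placeholders:

```go
r, cleanup := manual.GenerateAndRegisterManualResolver()
defer cleanup()
cc, err := Dial(r.Scheme()+":///test.server",
	WithInsecure(),
	WithBalancerBuilder(newPickfirstBuilder()))
if err != nil {
	grpclog.Fatalf("dial: %v", err)
}
defer cc.Close()
// pickfirst hands the whole address list to a single SubConn and sticks
// with whichever address that SubConn connects to.
r.NewAddress([]resolver.Address{{Addr: "127.0.0.1:5001"}, {Addr: "127.0.0.1:5002"}})
```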
diff --git a/pickfirst_test.go b/pickfirst_test.go
new file mode 100644
index 000000000..f28875437
--- /dev/null
+++ b/pickfirst_test.go
@@ -0,0 +1,352 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc
+
+import (
+	"math"
+	"sync"
+	"testing"
+	"time"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/resolver/manual"
+	"google.golang.org/grpc/test/leakcheck"
+)
+
+func TestOneBackendPickfirst(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+
+	numServers := 1
+	servers, _, scleanup := startServers(t, numServers, math.MaxInt32)
+	defer scleanup()
+
+	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	req := "port"
+	var reply string
+	if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	r.NewAddress([]resolver.Address{{Addr: servers[0].addr}})
+	// The second RPC should succeed.
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port {
+			return
+		}
+		time.Sleep(time.Millisecond)
+	}
+	t.Fatalf("EmptyCall() = _, %v, want _, %v", err, servers[0].port)
+}
+
+func TestBackendsPickfirst(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+
+	numServers := 2
+	servers, _, scleanup := startServers(t, numServers, math.MaxInt32)
+	defer scleanup()
+
+	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	req := "port"
+	var reply string
+	if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}})
+	// The second RPC should succeed with the first server.
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port {
+			return
+		}
+		time.Sleep(time.Millisecond)
+	}
+	t.Fatalf("EmptyCall() = _, %v, want _, %v", err, servers[0].port)
}
+
+func TestNewAddressWhileBlockingPickfirst(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+
+	numServers := 1
+	servers, _, scleanup := startServers(t, numServers, math.MaxInt32)
+	defer scleanup()
+
+	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	req := "port"
+	var reply string
+	if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	var wg sync.WaitGroup
+	for i := 0; i < 3; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			// This RPC blocks until NewAddress is called.
+			Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
+		}()
+	}
+	time.Sleep(50 * time.Millisecond)
+	r.NewAddress([]resolver.Address{{Addr: servers[0].addr}})
+	wg.Wait()
+}
+
+func TestCloseWithPendingRPCPickfirst(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+
+	numServers := 1
+	_, _, scleanup := startServers(t, numServers, math.MaxInt32)
+	defer scleanup()
+
+	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	req := "port"
+	var reply string
+	if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	var wg sync.WaitGroup
+	for i := 0; i < 3; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			// This RPC blocks until NewAddress is called.
+			Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
+		}()
+	}
+	time.Sleep(50 * time.Millisecond)
+	cc.Close()
+	wg.Wait()
+}
+
+func TestOneServerDownPickfirst(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+
+	numServers := 2
+	servers, _, scleanup := startServers(t, numServers, math.MaxInt32)
+	defer scleanup()
+
+	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	req := "port"
+	var reply string
+	if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}})
+	// The second RPC should succeed with the first server.
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+
+	servers[0].stop()
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
+			return
+		}
+		time.Sleep(time.Millisecond)
+	}
+	t.Fatalf("EmptyCall() = _, %v, want _, %v", err, servers[0].port)
+}
+
+func TestAllServersDownPickfirst(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+
+	numServers := 2
+	servers, _, scleanup := startServers(t, numServers, math.MaxInt32)
+	defer scleanup()
+
+	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	req := "port"
+	var reply string
+	if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}})
+	// The second RPC should succeed with the first server.
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+
+	for i := 0; i < numServers; i++ {
+		servers[i].stop()
+	}
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); Code(err) == codes.Unavailable {
+			return
+		}
+		time.Sleep(time.Millisecond)
+	}
+	t.Fatalf("EmptyCall() = _, %v, want _, error with code unavailable", err)
+}
+
+func TestAddressesRemovedPickfirst(t *testing.T) {
+	defer leakcheck.Check(t)
+	r, rcleanup := manual.GenerateAndRegisterManualResolver()
+	defer rcleanup()
+
+	numServers := 3
+	servers, _, scleanup := startServers(t, numServers, math.MaxInt32)
+	defer scleanup()
+
+	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithBalancerBuilder(newPickfirstBuilder()), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	// The first RPC should fail because there's no address.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+	defer cancel()
+	req := "port"
+	var reply string
+	if err := Invoke(ctx, "/foo/bar", &req, &reply, cc); err == nil || Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
+	}
+
+	r.NewAddress([]resolver.Address{{Addr: servers[0].addr}, {Addr: servers[1].addr}, {Addr: servers[2].addr}})
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	for i := 0; i < 20; i++ {
+		if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port {
+			t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// Remove server[0].
+	r.NewAddress([]resolver.Address{{Addr: servers[1].addr}, {Addr: servers[2].addr}})
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	for i := 0; i < 20; i++ {
+		if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port {
+			t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// Append server[0], nothing should change.
+	r.NewAddress([]resolver.Address{{Addr: servers[1].addr}, {Addr: servers[2].addr}, {Addr: servers[0].addr}})
+	for i := 0; i < 20; i++ {
+		if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[1].port {
+			t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 1, err, servers[1].port)
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// Remove server[1].
+	r.NewAddress([]resolver.Address{{Addr: servers[2].addr}, {Addr: servers[0].addr}})
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	for i := 0; i < 20; i++ {
+		if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[2].port {
+			t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 2, err, servers[2].port)
		}
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// Remove server[2].
+	r.NewAddress([]resolver.Address{{Addr: servers[0].addr}})
+	for i := 0; i < 1000; i++ {
+		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[0].port {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	for i := 0; i < 20; i++ {
+		if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port {
+			t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", 0, err, servers[0].port)
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+}
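The 1000-iteration, one-millisecond polling loop recurs in every test in this file. A hypothetical in-package helper (waitForServer is an invented name, not part of the patch) that mirrors the loop exactly:

```go
// waitForServer polls until an RPC is answered by the backend whose port is
// wantPort, failing the test after roughly one second.
func waitForServer(t *testing.T, cc *ClientConn, wantPort string) {
	req := "port"
	var reply string
	var err error
	for i := 0; i < 1000; i++ { // at most ~1s
		if err = Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == wantPort {
			return
		}
		time.Sleep(time.Millisecond)
	}
	t.Fatalf("EmptyCall() = _, %v, want _, %v", err, wantPort)
}
```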
diff --git a/resolver/manual/manual.go b/resolver/manual/manual.go
new file mode 100644
index 000000000..d9e1efeb2
--- /dev/null
+++ b/resolver/manual/manual.go
@@ -0,0 +1,82 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package manual contains a resolver for testing purposes only.
+package manual
+
+import (
+	"strconv"
+	"time"
+
+	"google.golang.org/grpc/resolver"
+)
+
+// NewBuilderWithScheme creates a new test resolver builder with the given scheme.
+func NewBuilderWithScheme(scheme string) *Resolver {
+	return &Resolver{
+		scheme: scheme,
+	}
+}
+
+// Resolver is also a resolver builder.
+// Its Build() function always returns itself.
+type Resolver struct {
+	scheme string
+
+	// Fields actually belong to the resolver.
+	target string
+	cc     resolver.ClientConn
+}
+
+// Build returns itself for Resolver, because it's both a builder and a resolver.
+func (r *Resolver) Build(target string, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
+	r.target = target
+	r.cc = cc
+	return r, nil
+}
+
+// Scheme returns the test scheme.
+func (r *Resolver) Scheme() string {
+	return r.scheme
+}
+
+// ResolveNow is a noop for Resolver.
+func (*Resolver) ResolveNow(o resolver.ResolveNowOption) {}
+
+// Close is a noop for Resolver.
+func (*Resolver) Close() {}
+
+// NewAddress calls cc.NewAddress.
+func (r *Resolver) NewAddress(addrs []resolver.Address) {
+	r.cc.NewAddress(addrs)
+}
+
+// NewServiceConfig calls cc.NewServiceConfig.
+func (r *Resolver) NewServiceConfig(sc string) {
+	r.cc.NewServiceConfig(sc)
+}
+
+// GenerateAndRegisterManualResolver generates a random scheme and a Resolver
+// with it. It also registers this Resolver.
+// It returns the Resolver and a cleanup function to unregister it.
+func GenerateAndRegisterManualResolver() (*Resolver, func()) {
+	scheme := strconv.FormatInt(time.Now().UnixNano(), 36)
+	r := NewBuilderWithScheme(scheme)
+	resolver.Register(r)
+	return r, func() { resolver.UnregisterForTesting(scheme) }
+}
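A usage sketch for the manual resolver with a fixed scheme, assuming nothing else registers "manualtest" concurrently (GenerateAndRegisterManualResolver exists precisely to sidestep such collisions with a random scheme; UnregisterForTesting is added to resolver/resolver.go below):

```go
r := manual.NewBuilderWithScheme("manualtest")
resolver.Register(r)
defer resolver.UnregisterForTesting("manualtest")
cc, err := grpc.Dial("manualtest:///ignored", grpc.WithInsecure())
if err != nil {
	log.Fatalf("dial: %v", err)
}
defer cc.Close()
// Drive the client from the test: the wrapper forwards this to the balancer.
r.NewAddress([]resolver.Address{{Addr: "127.0.0.1:5001"}}) // placeholder backend
```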
diff --git a/resolver/passthrough/passthrough.go b/resolver/passthrough/passthrough.go
new file mode 100644
index 000000000..79947c4cc
--- /dev/null
+++ b/resolver/passthrough/passthrough.go
@@ -0,0 +1,63 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package passthrough implements a pass-through resolver. It sends the target
+// name without scheme back to gRPC as the resolved address. It's for gRPC
+// internal testing only.
+package passthrough
+
+import (
+	"strings"
+
+	"google.golang.org/grpc/resolver"
+)
+
+const scheme = "passthrough"
+
+type passthroughBuilder struct{}
+
+func (*passthroughBuilder) Build(target string, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
+	r := &passthroughResolver{
+		target: target,
+		cc:     cc,
+	}
+	r.start()
+	return r, nil
+}
+
+func (*passthroughBuilder) Scheme() string {
+	return scheme
+}
+
+type passthroughResolver struct {
+	target string
+	cc     resolver.ClientConn
+}
+
+func (r *passthroughResolver) start() {
+	addr := strings.TrimPrefix(r.target, scheme+":///")
+	r.cc.NewAddress([]resolver.Address{{Addr: addr}})
+}
+
+func (*passthroughResolver) ResolveNow(o resolver.ResolveNowOption) {}
+
+func (*passthroughResolver) Close() {}
+
+func init() {
+	resolver.Register(&passthroughBuilder{})
+}
diff --git a/resolver/resolver.go b/resolver/resolver.go
index 918a6c7db..c15aa9b58 100644
--- a/resolver/resolver.go
+++ b/resolver/resolver.go
@@ -126,3 +126,10 @@ type Resolver interface {
 	// Close closes the resolver.
 	Close()
 }
+
+// UnregisterForTesting removes the resolver builder with the given scheme from the
+// resolver map.
+// This function is for testing only.
+func UnregisterForTesting(scheme string) {
+	delete(m, scheme)
+}
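What passthrough amounts to from the caller's side: the dial target minus the scheme prefix is handed straight back as the only resolved address, with no lookup of any kind. A standalone sketch (the address is a placeholder):

```go
package main

import (
	"log"

	"google.golang.org/grpc"
	_ "google.golang.org/grpc/resolver/passthrough" // registers the resolver at init
)

func main() {
	// The resolver immediately reports "localhost:50051" back as the one
	// address; no DNS resolution happens.
	cc, err := grpc.Dial("passthrough:///localhost:50051", grpc.WithInsecure())
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer cc.Close()
}
```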
diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go
new file mode 100644
index 000000000..217a59e6a
--- /dev/null
+++ b/resolver_conn_wrapper.go
@@ -0,0 +1,125 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc
+
+import (
+	"strings"
+
+	"google.golang.org/grpc/grpclog"
+	"google.golang.org/grpc/resolver"
+)
+
+// ccResolverWrapper is a wrapper on top of cc for resolvers.
+// It implements the resolver.ClientConn interface.
+type ccResolverWrapper struct {
+	cc       *ClientConn
+	resolver resolver.Resolver
+	addrCh   chan []resolver.Address
+	scCh     chan string
+	done     chan struct{}
+}
+
+// newCCResolverWrapper parses cc.target for scheme and gets the resolver
+// builder for this scheme. It then builds the resolver and starts the
+// monitoring goroutine for it.
+//
+// This function could return nil, nil in tests for old behaviors.
+// TODO(bar) never return nil, nil when DNS becomes the default resolver.
+func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
+	var scheme string
+	targetSplitted := strings.Split(cc.target, "://")
+	if len(targetSplitted) >= 2 {
+		scheme = targetSplitted[0]
+	}
+	grpclog.Infof("dialing to target with scheme: %q", scheme)
+
+	rb := resolver.Get(scheme)
+	if rb == nil {
+		// TODO(bar) return error when DNS becomes the default (implemented and
+		// registered by DNS package).
+		grpclog.Infof("could not get resolver for scheme: %q", scheme)
+		return nil, nil
+	}
+
+	ccr := &ccResolverWrapper{
+		cc:     cc,
+		addrCh: make(chan []resolver.Address, 1),
+		scCh:   make(chan string, 1),
+		done:   make(chan struct{}),
+	}
+
+	var err error
+	ccr.resolver, err = rb.Build(cc.target, ccr, resolver.BuildOption{})
+	if err != nil {
+		return nil, err
+	}
+	go ccr.watcher()
+	return ccr, nil
+}
+
+// watcher processes address updates and service config updates sequentially.
+// Otherwise, we would need to resolve possible races between address and
+// service config (e.g. they specify different balancer types).
+func (ccr *ccResolverWrapper) watcher() {
+	for {
+		select {
+		case <-ccr.done:
+			return
+		default:
+		}
+
+		select {
+		case addrs := <-ccr.addrCh:
+			grpclog.Infof("ccResolverWrapper: sending new addresses to balancer wrapper: %v", addrs)
+			// TODO(bar switching) this should never be nil. Pickfirst should be default.
+			if ccr.cc.balancerWrapper != nil {
+				// TODO(bar switching) create balancer if it's nil?
+				ccr.cc.balancerWrapper.handleResolvedAddrs(addrs, nil)
+			}
+		case sc := <-ccr.scCh:
+			grpclog.Infof("ccResolverWrapper: got new service config: %v", sc)
+		case <-ccr.done:
+			return
+		}
+	}
+}
+
+func (ccr *ccResolverWrapper) close() {
+	ccr.resolver.Close()
+	close(ccr.done)
+}
+
+// NewAddress is called by the resolver implementation to send addresses to gRPC.
+func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
+	select {
+	case <-ccr.addrCh:
+	default:
+	}
+	ccr.addrCh <- addrs
+}
+
+// NewServiceConfig is called by the resolver implementation to send service
+// configs to gRPC.
+func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
+	select {
+	case <-ccr.scCh:
+	default:
+	}
+	ccr.scCh <- sc
+}
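NewAddress and NewServiceConfig above both use a drain-then-send on a buffered channel of capacity one, so a slow watcher only ever observes the newest update and the resolver never blocks. Isolated, the idiom looks like this; note it is only race-free with a single producer, which holds here since one resolver feeds each wrapper:

```go
package main

import "fmt"

func main() {
	ch := make(chan []string, 1)
	send := func(v []string) {
		select {
		case <-ch: // drop a stale, unconsumed update
		default:
		}
		ch <- v // never blocks: the buffer slot is now free
	}
	send([]string{"addr1"})
	send([]string{"addr2"}) // overwrites; the consumer sees only the newest
	fmt.Println(<-ch)       // [addr2]
}
```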
diff --git a/stream.go b/stream.go
index 9a1965a47..75eab40b1 100644
--- a/stream.go
+++ b/stream.go
@@ -107,7 +107,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 	var (
 		t      transport.ClientTransport
 		s      *transport.Stream
-		put    func(balancer.DoneInfo)
+		done   func(balancer.DoneInfo)
 		cancel context.CancelFunc
 	)
 	c := defaultCallInfo()
@@ -189,11 +189,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 			}
 		}()
 	}
-	gopts := BalancerGetOptions{
-		BlockingWait: !c.failFast,
-	}
 	for {
-		t, put, err = cc.getTransport(ctx, gopts)
+		t, done, err = cc.getTransport(ctx, c.failFast)
 		if err != nil {
 			// TODO(zhaoq): Probably revisit the error handling.
 			if _, ok := status.FromError(err); ok {
@@ -211,15 +208,15 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		s, err = t.NewStream(ctx, callHdr)
 		if err != nil {
-			if _, ok := err.(transport.ConnectionError); ok && put != nil {
+			if _, ok := err.(transport.ConnectionError); ok && done != nil {
 				// If error is connection error, transport was sending data on wire,
 				// and we are not sure if anything has been sent on wire.
 				// If error is not connection error, we are sure nothing has been sent.
 				updateRPCInfoInContext(ctx, rpcInfo{bytesSent: true, bytesReceived: false})
 			}
-			if put != nil {
-				put(balancer.DoneInfo{Err: err})
-				put = nil
+			if done != nil {
+				done(balancer.DoneInfo{Err: err})
+				done = nil
 			}
 			if _, ok := err.(transport.ConnectionError); (ok || err == transport.ErrStreamDrain) && !c.failFast {
 				continue
@@ -241,10 +238,10 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		dc:     cc.dopts.dc,
 		cancel: cancel,
 
-		put: put,
-		t:   t,
-		s:   s,
-		p:   &parser{r: s},
+		done: done,
+		t:    t,
+		s:    s,
+		p:    &parser{r: s},
 
 		tracing: EnableTracing,
 		trInfo:  trInfo,
@@ -294,7 +291,7 @@ type clientStream struct {
 	tracing bool // set to EnableTracing when the clientStream is created.
 
 	mu       sync.Mutex
-	put      func(balancer.DoneInfo)
+	done     func(balancer.DoneInfo)
 	closed   bool
 	finished bool
 	// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
@@ -488,13 +485,13 @@ func (cs *clientStream) finish(err error) {
 	for _, o := range cs.opts {
 		o.after(cs.c)
 	}
-	if cs.put != nil {
+	if cs.done != nil {
 		updateRPCInfoInContext(cs.s.Context(), rpcInfo{
 			bytesSent:     cs.s.BytesSent(),
 			bytesReceived: cs.s.BytesReceived(),
 		})
-		cs.put(balancer.DoneInfo{Err: err})
-		cs.put = nil
+		cs.done(balancer.DoneInfo{Err: err})
+		cs.done = nil
 	}
 	if cs.statsHandler != nil {
 		end := &stats.End{
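The put-to-done rename above is cosmetic on the surface but signals intent: the callback a picker returns is invoked by clientStream.finish to report RPC completion. A hedged sketch of a picker that uses this for load reporting (reportingPicker and its report sink are invented names, not part of the patch):

```go
// reportingPicker wraps a single ready SubConn for brevity; a real picker
// would select among several.
type reportingPicker struct {
	sc     balancer.SubConn
	report func(sc balancer.SubConn, d time.Duration, err error) // hypothetical sink
}

func (p *reportingPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
	start := time.Now()
	done := func(di balancer.DoneInfo) {
		// Invoked once by the stream machinery with the RPC's final error.
		p.report(p.sc, time.Since(start), di.Err)
	}
	return p.sc, done, nil
}
```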
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 983a9b7e1..dd0c34308 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -46,6 +46,8 @@ import (
 	"golang.org/x/net/http2"
 	spb "google.golang.org/genproto/googleapis/rpc/status"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/balancer"
+	_ "google.golang.org/grpc/balancer/roundrobin"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/credentials"
@@ -55,6 +57,7 @@ import (
 	"google.golang.org/grpc/internal"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/peer"
+	_ "google.golang.org/grpc/resolver/passthrough"
 	"google.golang.org/grpc/stats"
 	"google.golang.org/grpc/status"
 	"google.golang.org/grpc/tap"
@@ -371,7 +374,7 @@ type env struct {
 	network      string // The type of network such as tcp, unix, etc.
 	security     string // The security protocol such as TLS, SSH, etc.
 	httpHandler  bool   // whether to use the http.Handler ServerTransport; requires TLS
-	balancer     bool   // whether to use balancer
+	balancer     string // One of "roundrobin", "pickfirst", "v1", or "".
 	customDialer func(string, string, time.Duration) (net.Conn, error)
 }
@@ -390,13 +393,13 @@ func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) {
 }
 
 var (
-	tcpClearEnv   = env{name: "tcp-clear", network: "tcp", balancer: true}
-	tcpTLSEnv     = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: true}
-	unixClearEnv  = env{name: "unix-clear", network: "unix", balancer: true}
-	unixTLSEnv    = env{name: "unix-tls", network: "unix", security: "tls", balancer: true}
-	handlerEnv    = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: true}
-	noBalancerEnv = env{name: "no-balancer", network: "tcp", security: "tls", balancer: false}
-	allEnv        = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv, noBalancerEnv}
+	tcpClearEnv   = env{name: "tcp-clear-v1-balancer", network: "tcp", balancer: "v1"}
+	tcpTLSEnv     = env{name: "tcp-tls-v1-balancer", network: "tcp", security: "tls", balancer: "v1"}
+	tcpClearRREnv = env{name: "tcp-clear", network: "tcp", balancer: "roundrobin"}
+	tcpTLSRREnv   = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: "roundrobin"}
+	handlerEnv    = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: "roundrobin"}
+	noBalancerEnv = env{name: "no-balancer", network: "tcp", security: "tls"}
+	allEnv        = []env{tcpClearEnv, tcpTLSEnv, tcpClearRREnv, tcpTLSRREnv, handlerEnv, noBalancerEnv}
 )
 
 var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.")
@@ -638,8 +641,18 @@ func (te *test) clientConn() *grpc.ClientConn {
 	default:
 		opts = append(opts, grpc.WithInsecure())
 	}
-	if te.e.balancer {
+	// TODO(bar) switch balancer case "pickfirst".
+	var scheme string
+	switch te.e.balancer {
+	case "v1":
 		opts = append(opts, grpc.WithBalancer(grpc.RoundRobin(nil)))
+	case "roundrobin":
+		rr := balancer.Get("roundrobin")
+		if rr == nil {
+			te.t.Fatalf("got nil when trying to get roundrobin balancer builder")
+		}
+		opts = append(opts, grpc.WithBalancerBuilder(rr))
+		scheme = "passthrough:///"
 	}
 	if te.clientInitialWindowSize > 0 {
 		opts = append(opts, grpc.WithInitialWindowSize(te.clientInitialWindowSize))
@@ -658,9 +671,9 @@ func (te *test) clientConn() *grpc.ClientConn {
 		opts = append(opts, grpc.WithBlock())
 	}
 	var err error
-	te.cc, err = grpc.Dial(te.srvAddr, opts...)
+	te.cc, err = grpc.Dial(scheme+te.srvAddr, opts...)
 	if err != nil {
-		te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
+		te.t.Fatalf("Dial(%q) = %v", scheme+te.srvAddr, err)
 	}
 	return te.cc
 }
@@ -760,7 +773,7 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
 	_, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false))
 	cancel()
-	if e.balancer && grpc.Code(err) != codes.DeadlineExceeded {
+	if e.balancer != "" && grpc.Code(err) != codes.DeadlineExceeded {
 		// If e.balancer == nil, the ac will stop reconnecting because the dialer returns non-temp error,
 		// the error will be an internal error.
 		t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %s", ctx, err, codes.DeadlineExceeded)
@@ -4078,7 +4091,7 @@ func (c clientAlwaysFailCred) OverrideServerName(s string) error {
 }
 
 func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
-	te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true})
+	te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "v1"})
 	te.startServer(&testServer{security: te.e.security})
 	defer te.tearDown()
 
@@ -4093,61 +4106,6 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
 	}
 }
 
-func TestFailFastRPCErrorOnBadCertificates(t *testing.T) {
-	te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true})
-	te.startServer(&testServer{security: te.e.security})
-	defer te.tearDown()
-
-	te.nonBlockingDial = true // Connection will never be successful because server is not set up correctly.
-	cc := te.clientConn()
-	tc := testpb.NewTestServiceClient(cc)
-	var err error
-	for i := 0; i < 1000; i++ {
-		// This loop runs for at most 1 second.
-		// The first several RPCs will fail with Unavailable because the connection hasn't started.
-		// When the first connection failed with creds error, the next RPC should also fail with the expected error.
-		if _, err = tc.EmptyCall(context.Background(), &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
-			return
-		}
-		time.Sleep(time.Millisecond)
-	}
-	te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
-}
-
-func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) {
-	te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false})
-	te.startServer(&testServer{security: te.e.security})
-	defer te.tearDown()
-
-	te.nonBlockingDial = true
-	cc := te.clientConn()
-	tc := testpb.NewTestServiceClient(cc)
-	var err error
-	for i := 0; i < 1000; i++ {
-		// This loop runs for at most 1 second.
-		// The first several RPCs will fail with Unavailable because the connection hasn't started.
-		// When the first connection failed with creds error, the next RPC should also fail with the expected error.
-		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
-			return
-		}
-		time.Sleep(time.Millisecond)
-	}
-	te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
-}
-
-func TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) {
-	te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false})
-	te.startServer(&testServer{security: te.e.security})
-	defer te.tearDown()
-
-	te.nonBlockingDial = true
-	cc := te.clientConn()
-	tc := testpb.NewTestServiceClient(cc)
-	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) {
-		te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg)
-	}
-}
-
 type clientTimeoutCreds struct {
 	timeoutReturned bool
 }
@@ -4173,7 +4131,7 @@ func (c *clientTimeoutCreds) OverrideServerName(s string) error {
 }
 
 func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) {
-	te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: false})
+	te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: "v1"})
 	te.userAgent = testAppUA
 	te.startServer(&testServer{security: te.e.security})
 	defer te.tearDown()