Allow gRPC clients to connect to multiple backends (#1918)
Fixes #1917 and #1755, also updates google.golang.org/grpc to b60d3e9e.
This commit is contained in:
		
							parent
							
								
									f04b922aff
								
							
						
					
					
						commit
						92e0704b1b
					
				| 
						 | 
				
			
			@ -200,39 +200,39 @@
 | 
			
		|||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc/codes",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc/credentials",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc/grpclog",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc/internal",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc/metadata",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc/naming",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc/peer",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "google.golang.org/grpc/transport",
 | 
			
		||||
			"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
 | 
			
		||||
			"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"ImportPath": "gopkg.in/gorp.v1",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,14 +13,14 @@ import (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	addr := flag.String("addr", "127.0.0.1:9090", "CCS address")
 | 
			
		||||
	addr := flag.String("addr", "boulder:9090", "CCS address")
 | 
			
		||||
	name := flag.String("name", "", "Name to check")
 | 
			
		||||
	issuer := flag.String("issuerDomain", "", "Issuer domain to check against")
 | 
			
		||||
	flag.Parse()
 | 
			
		||||
 | 
			
		||||
	// Set up a connection to the server.
 | 
			
		||||
	conn, err := bgrpc.ClientSetup(&cmd.GRPCClientConfig{
 | 
			
		||||
		ServerAddress:         *addr,
 | 
			
		||||
		ServerAddresses:       []string{*addr},
 | 
			
		||||
		ServerIssuerPath:      "test/grpc-creds/ca.pem",
 | 
			
		||||
		ClientCertificatePath: "test/grpc-creds/client.pem",
 | 
			
		||||
		ClientKeyPath:         "test/grpc-creds/key.pem",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -511,7 +511,7 @@ type LogDescription struct {
 | 
			
		|||
 | 
			
		||||
// GRPCClientConfig contains the information needed to talk to the gRPC service
 | 
			
		||||
type GRPCClientConfig struct {
 | 
			
		||||
	ServerAddress         string
 | 
			
		||||
	ServerAddresses       []string
 | 
			
		||||
	ServerIssuerPath      string
 | 
			
		||||
	ClientCertificatePath string
 | 
			
		||||
	ClientKeyPath         string
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,47 @@
 | 
			
		|||
package grpc
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"google.golang.org/grpc/naming"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// staticResolver implements both the naming.Resolver and naming.Watcher
 | 
			
		||||
// interfaces. It always returns a single static list then blocks forever
 | 
			
		||||
type staticResolver struct {
 | 
			
		||||
	addresses []*naming.Update
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newStaticResolver(addresses []string) *staticResolver {
 | 
			
		||||
	sr := &staticResolver{}
 | 
			
		||||
	for _, a := range addresses {
 | 
			
		||||
		sr.addresses = append(sr.addresses, &naming.Update{
 | 
			
		||||
			Op:   naming.Add,
 | 
			
		||||
			Addr: a,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return sr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Resolve just returns the staticResolver it was called from as it satisfies
 | 
			
		||||
// both the naming.Resolver and naming.Watcher interfaces
 | 
			
		||||
func (sr *staticResolver) Resolve(target string) (naming.Watcher, error) {
 | 
			
		||||
	return sr, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Next is called in a loop by grpc.RoundRobin expecting updates to which addresses are
 | 
			
		||||
// appropriate. Since we just want to return a static list once return a list on the first
 | 
			
		||||
// call then block forever on the second instead of sitting in a tight loop
 | 
			
		||||
func (sr *staticResolver) Next() ([]*naming.Update, error) {
 | 
			
		||||
	if sr.addresses != nil {
 | 
			
		||||
		addrs := sr.addresses
 | 
			
		||||
		sr.addresses = nil
 | 
			
		||||
		return addrs, nil
 | 
			
		||||
	}
 | 
			
		||||
	// Since staticResolver.Next is called in a tight loop block forever
 | 
			
		||||
	// after returning the initial set of addresses
 | 
			
		||||
	forever := make(chan struct{})
 | 
			
		||||
	<-forever
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close does nothing
 | 
			
		||||
func (sr *staticResolver) Close() {}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,39 @@
 | 
			
		|||
package grpc
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"google.golang.org/grpc/naming"
 | 
			
		||||
 | 
			
		||||
	"github.com/letsencrypt/boulder/test"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestStaticResolver(t *testing.T) {
 | 
			
		||||
	names := []string{"test:443"}
 | 
			
		||||
	sr := newStaticResolver(names)
 | 
			
		||||
	watcher, err := sr.Resolve("")
 | 
			
		||||
	test.AssertNotError(t, err, "staticResolver.Resolve failed")
 | 
			
		||||
 | 
			
		||||
	// Make sure doing this doesn't break anything (since it does nothing)
 | 
			
		||||
	watcher.Close()
 | 
			
		||||
 | 
			
		||||
	updates, err := watcher.Next()
 | 
			
		||||
	test.AssertNotError(t, err, "staticwatcher.Next failed")
 | 
			
		||||
	test.AssertEquals(t, len(names), len(updates))
 | 
			
		||||
	test.AssertEquals(t, updates[0].Addr, "test:443")
 | 
			
		||||
	test.AssertEquals(t, updates[0].Op, naming.Add)
 | 
			
		||||
	test.AssertEquals(t, updates[0].Metadata, nil)
 | 
			
		||||
 | 
			
		||||
	returned := make(chan struct{}, 1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		_, err = watcher.Next()
 | 
			
		||||
		test.AssertNotError(t, err, "watcher.Next failed")
 | 
			
		||||
		returned <- struct{}{}
 | 
			
		||||
	}()
 | 
			
		||||
	select {
 | 
			
		||||
	case <-returned:
 | 
			
		||||
		t.Fatal("staticWatcher.Next returned something after the first call")
 | 
			
		||||
	case <-time.After(time.Millisecond * 500):
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,77 @@
 | 
			
		|||
package creds
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// transportCredentials is a grpc/credentials.TransportCredentials which supports
 | 
			
		||||
// connecting to, and verifying multiple DNS names
 | 
			
		||||
type transportCredentials struct {
 | 
			
		||||
	roots   *x509.CertPool
 | 
			
		||||
	clients []tls.Certificate
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New returns a new initialized grpc/credentials.TransportCredentials
 | 
			
		||||
func New(rootCAs *x509.CertPool, clientCerts []tls.Certificate) credentials.TransportCredentials {
 | 
			
		||||
	return &transportCredentials{rootCAs, clientCerts}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ClientHandshake performs the TLS handshake for a client -> server connection
 | 
			
		||||
func (tc *transportCredentials) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, credentials.AuthInfo, error) {
 | 
			
		||||
	host, _, err := net.SplitHostPort(addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	conn := tls.Client(rawConn, &tls.Config{
 | 
			
		||||
		ServerName:   host,
 | 
			
		||||
		RootCAs:      tc.roots,
 | 
			
		||||
		Certificates: tc.clients,
 | 
			
		||||
		MinVersion:   tls.VersionTLS12, // Override default of tls.VersionTLS10
 | 
			
		||||
		MaxVersion:   tls.VersionTLS12, // Same as default in golang <= 1.6
 | 
			
		||||
	})
 | 
			
		||||
	errChan := make(chan error, 1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		errChan <- conn.Handshake()
 | 
			
		||||
	}()
 | 
			
		||||
	select {
 | 
			
		||||
	case <-time.After(timeout):
 | 
			
		||||
		return nil, nil, errors.New("boulder/grpc/creds: TLS handshake timed out")
 | 
			
		||||
	case err := <-errChan:
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			_ = rawConn.Close()
 | 
			
		||||
			return nil, nil, fmt.Errorf("boulder/grpc/creds: TLS handshake failed: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
		return conn, nil, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServerHandshake performs the TLS handshake for a server <- client connection
 | 
			
		||||
func (tc *transportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 | 
			
		||||
	return nil, nil, fmt.Errorf("boulder/grpc/creds: Server-side handshakes are not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Info returns information about the transport protocol used
 | 
			
		||||
func (tc *transportCredentials) Info() credentials.ProtocolInfo {
 | 
			
		||||
	return credentials.ProtocolInfo{
 | 
			
		||||
		SecurityProtocol: "tls",
 | 
			
		||||
		SecurityVersion:  "1.2", // We *only* support TLS 1.2
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRequestMetadata returns nil, nil since TLS credentials do not have metadata.
 | 
			
		||||
func (tc *transportCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequireTransportSecurity always returns true because TLS is transport security
 | 
			
		||||
func (tc *transportCredentials) RequireTransportSecurity() bool {
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,103 @@
 | 
			
		|||
package creds
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/rsa"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"crypto/x509/pkix"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/letsencrypt/boulder/test"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestTransportCredentials(t *testing.T) {
 | 
			
		||||
	priv, err := rsa.GenerateKey(rand.Reader, 1024)
 | 
			
		||||
	test.AssertNotError(t, err, "rsa.GenerateKey failed")
 | 
			
		||||
 | 
			
		||||
	temp := &x509.Certificate{
 | 
			
		||||
		SerialNumber: big.NewInt(1),
 | 
			
		||||
		Subject: pkix.Name{
 | 
			
		||||
			CommonName: "A",
 | 
			
		||||
		},
 | 
			
		||||
		NotBefore:             time.Unix(1000, 0),
 | 
			
		||||
		NotAfter:              time.Now().AddDate(1, 0, 0),
 | 
			
		||||
		BasicConstraintsValid: true,
 | 
			
		||||
		IsCA: true,
 | 
			
		||||
	}
 | 
			
		||||
	derA, err := x509.CreateCertificate(rand.Reader, temp, temp, priv.Public(), priv)
 | 
			
		||||
	test.AssertNotError(t, err, "x509.CreateCertificate failed")
 | 
			
		||||
	certA, err := x509.ParseCertificate(derA)
 | 
			
		||||
	test.AssertNotError(t, err, "x509.ParserCertificate failed")
 | 
			
		||||
	temp.Subject.CommonName = "B"
 | 
			
		||||
	derB, err := x509.CreateCertificate(rand.Reader, temp, temp, priv.Public(), priv)
 | 
			
		||||
	test.AssertNotError(t, err, "x509.CreateCertificate failed")
 | 
			
		||||
	certB, err := x509.ParseCertificate(derB)
 | 
			
		||||
	test.AssertNotError(t, err, "x509.ParserCertificate failed")
 | 
			
		||||
	roots := x509.NewCertPool()
 | 
			
		||||
	roots.AddCert(certA)
 | 
			
		||||
	roots.AddCert(certB)
 | 
			
		||||
 | 
			
		||||
	serverA := httptest.NewUnstartedServer(nil)
 | 
			
		||||
	serverA.TLS = &tls.Config{Certificates: []tls.Certificate{{Certificate: [][]byte{derA}, PrivateKey: priv}}}
 | 
			
		||||
	serverB := httptest.NewUnstartedServer(nil)
 | 
			
		||||
	serverB.TLS = &tls.Config{Certificates: []tls.Certificate{{Certificate: [][]byte{derB}, PrivateKey: priv}}}
 | 
			
		||||
 | 
			
		||||
	tc := New(roots, nil)
 | 
			
		||||
 | 
			
		||||
	serverA.StartTLS()
 | 
			
		||||
	defer serverA.Close()
 | 
			
		||||
	addrA := serverA.Listener.Addr().String()
 | 
			
		||||
	rawConnA, err := net.Dial("tcp", addrA)
 | 
			
		||||
	test.AssertNotError(t, err, "net.Dial failed")
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = rawConnA.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	conn, _, err := tc.ClientHandshake("A:2020", rawConnA, time.Second)
 | 
			
		||||
	test.AssertNotError(t, err, "tc.ClientHandshake failed")
 | 
			
		||||
	test.Assert(t, conn != nil, "tc.ClientHandshake returned a nil net.Conn")
 | 
			
		||||
 | 
			
		||||
	serverB.StartTLS()
 | 
			
		||||
	defer serverB.Close()
 | 
			
		||||
	addrB := serverB.Listener.Addr().String()
 | 
			
		||||
	rawConnB, err := net.Dial("tcp", addrB)
 | 
			
		||||
	test.AssertNotError(t, err, "net.Dial failed")
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = rawConnB.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	conn, _, err = tc.ClientHandshake("B:3030", rawConnB, time.Second)
 | 
			
		||||
	test.AssertNotError(t, err, "tc.ClientHandshake failed")
 | 
			
		||||
	test.Assert(t, conn != nil, "tc.ClientHandshake returned a nil net.Conn")
 | 
			
		||||
 | 
			
		||||
	// Test timeout
 | 
			
		||||
	ln, err := net.Listen("tcp", "127.0.0.1:0")
 | 
			
		||||
	test.AssertNotError(t, err, "net.Listen failed")
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = ln.Close()
 | 
			
		||||
	}()
 | 
			
		||||
	addrC := ln.Addr().String()
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			_, err := ln.Accept()
 | 
			
		||||
			test.AssertNotError(t, err, "ln.Accept failed")
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	rawConnC, err := net.Dial("tcp", addrC)
 | 
			
		||||
	test.AssertNotError(t, err, "net.Dial failed")
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = rawConnB.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	conn, _, err = tc.ClientHandshake("A:2020", rawConnC, time.Millisecond)
 | 
			
		||||
	test.AssertError(t, err, "tc.ClientHandshake didn't timeout")
 | 
			
		||||
	test.AssertEquals(t, err.Error(), "boulder/grpc/creds: TLS handshake timed out")
 | 
			
		||||
	test.Assert(t, conn == nil, "tc.ClientHandshake returned a non-nil net.Conn on failure")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										20
									
								
								grpc/util.go
								
								
								
								
							
							
						
						
									
										20
									
								
								grpc/util.go
								
								
								
								
							| 
						 | 
				
			
			@ -12,17 +12,21 @@ import (
 | 
			
		|||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
 | 
			
		||||
	"github.com/letsencrypt/boulder/cmd"
 | 
			
		||||
	bcreds "github.com/letsencrypt/boulder/grpc/creds"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CodedError is a alias required to appease go vet
 | 
			
		||||
var CodedError = grpc.Errorf
 | 
			
		||||
 | 
			
		||||
// ClientSetup loads various TLS certificates and creates a
 | 
			
		||||
// gRPC TransportAuthenticator that presents the client certificate
 | 
			
		||||
// gRPC TransportCredentials that presents the client certificate
 | 
			
		||||
// and validates the certificate presented by the server is for a
 | 
			
		||||
// specific hostname and issued by the provided issuer certificate
 | 
			
		||||
// thens dials and returns a grpc.ClientConn to the remote service.
 | 
			
		||||
func ClientSetup(c *cmd.GRPCClientConfig) (*grpc.ClientConn, error) {
 | 
			
		||||
	if len(c.ServerAddresses) == 0 {
 | 
			
		||||
		return nil, fmt.Errorf("boulder/grpc: ServerAddresses is empty")
 | 
			
		||||
	}
 | 
			
		||||
	serverIssuerBytes, err := ioutil.ReadFile(c.ServerIssuerPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
| 
						 | 
				
			
			@ -35,15 +39,11 @@ func ClientSetup(c *cmd.GRPCClientConfig) (*grpc.ClientConn, error) {
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	host, _, err := net.SplitHostPort(c.ServerAddress)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return grpc.Dial(c.ServerAddress, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
 | 
			
		||||
		ServerName:   host,
 | 
			
		||||
		RootCAs:      rootCAs,
 | 
			
		||||
		Certificates: []tls.Certificate{clientCert},
 | 
			
		||||
	})))
 | 
			
		||||
	return grpc.Dial(
 | 
			
		||||
		"", // Since our staticResolver provides addresses we don't need to pass an address here
 | 
			
		||||
		grpc.WithTransportCredentials(bcreds.New(rootCAs, []tls.Certificate{clientCert})),
 | 
			
		||||
		grpc.WithBalancer(grpc.RoundRobin(newStaticResolver(c.ServerAddresses))),
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewServer loads various TLS certificates and creates a
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -137,7 +137,7 @@
 | 
			
		|||
    },
 | 
			
		||||
    "maxConcurrentRPCServerRequests": 16,
 | 
			
		||||
    "publisherService": {
 | 
			
		||||
      "serverAddress": "boulder:9091",
 | 
			
		||||
      "serverAddresses": ["boulder:9091"],
 | 
			
		||||
      "serverIssuerPath": "test/grpc-creds/ca.pem",
 | 
			
		||||
      "clientCertificatePath": "test/grpc-creds/client.pem",
 | 
			
		||||
      "clientKeyPath": "test/grpc-creds/key.pem",
 | 
			
		||||
| 
						 | 
				
			
			@ -173,7 +173,7 @@
 | 
			
		|||
    "doNotForceCN": true,
 | 
			
		||||
    "reuseValidAuthz": true,
 | 
			
		||||
    "vaService": {
 | 
			
		||||
      "serverAddress": "boulder:9092",
 | 
			
		||||
      "serverAddresses": ["boulder:9092"],
 | 
			
		||||
      "serverIssuerPath": "test/grpc-creds/ca.pem",
 | 
			
		||||
      "clientCertificatePath": "test/grpc-creds/client.pem",
 | 
			
		||||
      "clientKeyPath": "test/grpc-creds/key.pem",
 | 
			
		||||
| 
						 | 
				
			
			@ -224,7 +224,7 @@
 | 
			
		|||
    "dnsTries": 3,
 | 
			
		||||
    "issuerDomain": "happy-hacker-ca.invalid",
 | 
			
		||||
    "caaService": {
 | 
			
		||||
      "serverAddress": "boulder:9090",
 | 
			
		||||
      "serverAddresses": ["boulder:9090"],
 | 
			
		||||
      "serverIssuerPath": "test/grpc-creds/ca.pem",
 | 
			
		||||
      "clientCertificatePath": "test/grpc-creds/client.pem",
 | 
			
		||||
      "clientKeyPath": "test/grpc-creds/key.pem"
 | 
			
		||||
| 
						 | 
				
			
			@ -297,7 +297,7 @@
 | 
			
		|||
    "signFailureBackoffMax": "30m",
 | 
			
		||||
    "debugAddr": "localhost:8006",
 | 
			
		||||
    "publisher": {
 | 
			
		||||
      "serverAddress": "boulder:9091",
 | 
			
		||||
      "serverAddresses": ["boulder:9091"],
 | 
			
		||||
      "serverIssuerPath": "test/grpc-creds/ca.pem",
 | 
			
		||||
      "clientCertificatePath": "test/grpc-creds/client.pem",
 | 
			
		||||
      "clientKeyPath": "test/grpc-creds/key.pem",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,8 +21,9 @@ proto:
 | 
			
		|||
		exit 1; \
 | 
			
		||||
	fi
 | 
			
		||||
	go get -u -v github.com/golang/protobuf/protoc-gen-go
 | 
			
		||||
	for file in $$(git ls-files '*.proto'); do \
 | 
			
		||||
		protoc -I $$(dirname $$file) --go_out=plugins=grpc:$$(dirname $$file) $$file; \
 | 
			
		||||
	# use $$dir as the root for all proto files in the same directory
 | 
			
		||||
	for dir in $$(git ls-files '*.proto' | xargs -n1 dirname | uniq); do \
 | 
			
		||||
		protoc -I $$dir --go_out=plugins=grpc:$$dir $$dir/*.proto; \
 | 
			
		||||
	done
 | 
			
		||||
 | 
			
		||||
test: testdeps
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,22 +1,22 @@
 | 
			
		|||
Additional IP Rights Grant (Patents)
 | 
			
		||||
 | 
			
		||||
"This implementation" means the copyrightable works distributed by
 | 
			
		||||
Google as part of the GRPC project.
 | 
			
		||||
Google as part of the gRPC project.
 | 
			
		||||
 | 
			
		||||
Google hereby grants to You a perpetual, worldwide, non-exclusive,
 | 
			
		||||
no-charge, royalty-free, irrevocable (except as stated in this section)
 | 
			
		||||
patent license to make, have made, use, offer to sell, sell, import,
 | 
			
		||||
transfer and otherwise run, modify and propagate the contents of this
 | 
			
		||||
implementation of GRPC, where such license applies only to those patent
 | 
			
		||||
implementation of gRPC, where such license applies only to those patent
 | 
			
		||||
claims, both currently owned or controlled by Google and acquired in
 | 
			
		||||
the future, licensable by Google that are necessarily infringed by this
 | 
			
		||||
implementation of GRPC.  This grant does not include claims that would be
 | 
			
		||||
implementation of gRPC.  This grant does not include claims that would be
 | 
			
		||||
infringed only as a consequence of further modification of this
 | 
			
		||||
implementation.  If you or your agent or exclusive licensee institute or
 | 
			
		||||
order or agree to the institution of patent litigation against any
 | 
			
		||||
entity (including a cross-claim or counterclaim in a lawsuit) alleging
 | 
			
		||||
that this implementation of GRPC or any code incorporated within this
 | 
			
		||||
implementation of GRPC constitutes direct or contributory patent
 | 
			
		||||
that this implementation of gRPC or any code incorporated within this
 | 
			
		||||
implementation of gRPC constitutes direct or contributory patent
 | 
			
		||||
infringement, or inducement of patent infringement, then any patent
 | 
			
		||||
rights granted to you under this License for this implementation of GRPC
 | 
			
		||||
rights granted to you under this License for this implementation of gRPC
 | 
			
		||||
shall terminate as of the date such litigation is filed.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,7 +19,7 @@ var (
 | 
			
		|||
// backoffStrategy defines the methodology for backing off after a grpc
 | 
			
		||||
// connection failure.
 | 
			
		||||
//
 | 
			
		||||
// This is unexported until the GRPC project decides whether or not to allow
 | 
			
		||||
// This is unexported until the gRPC project decides whether or not to allow
 | 
			
		||||
// alternative backoff strategies. Once a decision is made, this type and its
 | 
			
		||||
// method may be exported.
 | 
			
		||||
type backoffStrategy interface {
 | 
			
		||||
| 
						 | 
				
			
			@ -28,14 +28,14 @@ type backoffStrategy interface {
 | 
			
		|||
	backoff(retries int) time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BackoffConfig defines the parameters for the default GRPC backoff strategy.
 | 
			
		||||
// BackoffConfig defines the parameters for the default gRPC backoff strategy.
 | 
			
		||||
type BackoffConfig struct {
 | 
			
		||||
	// MaxDelay is the upper bound of backoff delay.
 | 
			
		||||
	MaxDelay time.Duration
 | 
			
		||||
 | 
			
		||||
	// TODO(stevvooe): The following fields are not exported, as allowing
 | 
			
		||||
	// changes would violate the current GRPC specification for backoff. If
 | 
			
		||||
	// GRPC decides to allow more interesting backoff strategies, these fields
 | 
			
		||||
	// changes would violate the current gRPC specification for backoff. If
 | 
			
		||||
	// gRPC decides to allow more interesting backoff strategies, these fields
 | 
			
		||||
	// may be opened up in the future.
 | 
			
		||||
 | 
			
		||||
	// baseDelay is the amount of time to wait before retrying after the first
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,340 @@
 | 
			
		|||
/*
 | 
			
		||||
 *
 | 
			
		||||
 * Copyright 2016, Google Inc.
 | 
			
		||||
 * All rights reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * Redistribution and use in source and binary forms, with or without
 | 
			
		||||
 * modification, are permitted provided that the following conditions are
 | 
			
		||||
 * met:
 | 
			
		||||
 *
 | 
			
		||||
 *     * Redistributions of source code must retain the above copyright
 | 
			
		||||
 * notice, this list of conditions and the following disclaimer.
 | 
			
		||||
 *     * Redistributions in binary form must reproduce the above
 | 
			
		||||
 * copyright notice, this list of conditions and the following disclaimer
 | 
			
		||||
 * in the documentation and/or other materials provided with the
 | 
			
		||||
 * distribution.
 | 
			
		||||
 *     * Neither the name of Google Inc. nor the names of its
 | 
			
		||||
 * contributors may be used to endorse or promote products derived from
 | 
			
		||||
 * this software without specific prior written permission.
 | 
			
		||||
 *
 | 
			
		||||
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 | 
			
		||||
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 | 
			
		||||
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 | 
			
		||||
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 | 
			
		||||
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 | 
			
		||||
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 | 
			
		||||
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 | 
			
		||||
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 | 
			
		||||
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 | 
			
		||||
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | 
			
		||||
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package grpc
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
	"google.golang.org/grpc/grpclog"
 | 
			
		||||
	"google.golang.org/grpc/naming"
 | 
			
		||||
	"google.golang.org/grpc/transport"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Address represents a server the client connects to.
 | 
			
		||||
// This is the EXPERIMENTAL API and may be changed or extended in the future.
 | 
			
		||||
type Address struct {
 | 
			
		||||
	// Addr is the server address on which a connection will be established.
 | 
			
		||||
	Addr string
 | 
			
		||||
	// Metadata is the information associated with Addr, which may be used
 | 
			
		||||
	// to make load balancing decision.
 | 
			
		||||
	Metadata interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BalancerGetOptions configures a Get call.
 | 
			
		||||
// This is the EXPERIMENTAL API and may be changed or extended in the future.
 | 
			
		||||
type BalancerGetOptions struct {
 | 
			
		||||
	// BlockingWait specifies whether Get should block when there is no
 | 
			
		||||
	// connected address.
 | 
			
		||||
	BlockingWait bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Balancer chooses network addresses for RPCs.
 | 
			
		||||
// This is the EXPERIMENTAL API and may be changed or extended in the future.
 | 
			
		||||
type Balancer interface {
 | 
			
		||||
	// Start does the initialization work to bootstrap a Balancer. For example,
 | 
			
		||||
	// this function may start the name resolution and watch the updates. It will
 | 
			
		||||
	// be called when dialing.
 | 
			
		||||
	Start(target string) error
 | 
			
		||||
	// Up informs the Balancer that gRPC has a connection to the server at
 | 
			
		||||
	// addr. It returns down which is called once the connection to addr gets
 | 
			
		||||
	// lost or closed.
 | 
			
		||||
	// TODO: It is not clear how to construct and take advantage the meaningful error
 | 
			
		||||
	// parameter for down. Need realistic demands to guide.
 | 
			
		||||
	Up(addr Address) (down func(error))
 | 
			
		||||
	// Get gets the address of a server for the RPC corresponding to ctx.
 | 
			
		||||
	// i) If it returns a connected address, gRPC internals issues the RPC on the
 | 
			
		||||
	// connection to this address;
 | 
			
		||||
	// ii) If it returns an address on which the connection is under construction
 | 
			
		||||
	// (initiated by Notify(...)) but not connected, gRPC internals
 | 
			
		||||
	//  * fails RPC if the RPC is fail-fast and connection is in the TransientFailure or
 | 
			
		||||
	//  Shutdown state;
 | 
			
		||||
	//  or
 | 
			
		||||
	//  * issues RPC on the connection otherwise.
 | 
			
		||||
	// iii) If it returns an address on which the connection does not exist, gRPC
 | 
			
		||||
	// internals treats it as an error and will fail the corresponding RPC.
 | 
			
		||||
	//
 | 
			
		||||
	// Therefore, the following is the recommended rule when writing a custom Balancer.
 | 
			
		||||
	// If opts.BlockingWait is true, it should return a connected address or
 | 
			
		||||
	// block if there is no connected address. It should respect the timeout or
 | 
			
		||||
	// cancellation of ctx when blocking. If opts.BlockingWait is false (for fail-fast
 | 
			
		||||
	// RPCs), it should return an address it has notified via Notify(...) immediately
 | 
			
		||||
	// instead of blocking.
 | 
			
		||||
	//
 | 
			
		||||
	// The function returns put which is called once the rpc has completed or failed.
 | 
			
		||||
	// put can collect and report RPC stats to a remote load balancer. gRPC internals
 | 
			
		||||
	// will try to call this again if err is non-nil (unless err is ErrClientConnClosing).
 | 
			
		||||
	//
 | 
			
		||||
	// TODO: Add other non-recoverable errors?
 | 
			
		||||
	Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
 | 
			
		||||
	// Notify returns a channel that is used by gRPC internals to watch the addresses
 | 
			
		||||
	// gRPC needs to connect. The addresses might be from a name resolver or remote
 | 
			
		||||
	// load balancer. gRPC internals will compare it with the existing connected
 | 
			
		||||
	// addresses. If the address Balancer notified is not in the existing connected
 | 
			
		||||
	// addresses, gRPC starts to connect the address. If an address in the existing
 | 
			
		||||
	// connected addresses is not in the notification list, the corresponding connection
 | 
			
		||||
	// is shutdown gracefully. Otherwise, there are no operations to take. Note that
 | 
			
		||||
	// the Address slice must be the full list of the Addresses which should be connected.
 | 
			
		||||
	// It is NOT delta.
 | 
			
		||||
	Notify() <-chan []Address
 | 
			
		||||
	// Close shuts down the balancer.
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// downErr implements net.Error. It is constructed by gRPC internals and passed to the down
 | 
			
		||||
// call of Balancer.
 | 
			
		||||
type downErr struct {
 | 
			
		||||
	timeout   bool
 | 
			
		||||
	temporary bool
 | 
			
		||||
	desc      string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e downErr) Error() string   { return e.desc }
 | 
			
		||||
func (e downErr) Timeout() bool   { return e.timeout }
 | 
			
		||||
func (e downErr) Temporary() bool { return e.temporary }
 | 
			
		||||
 | 
			
		||||
func downErrorf(timeout, temporary bool, format string, a ...interface{}) downErr {
 | 
			
		||||
	return downErr{
 | 
			
		||||
		timeout:   timeout,
 | 
			
		||||
		temporary: temporary,
 | 
			
		||||
		desc:      fmt.Sprintf(format, a...),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
 | 
			
		||||
// the name resolution updates and updates the addresses available correspondingly.
 | 
			
		||||
func RoundRobin(r naming.Resolver) Balancer {
 | 
			
		||||
	return &roundRobin{r: r}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type roundRobin struct {
 | 
			
		||||
	r         naming.Resolver
 | 
			
		||||
	w         naming.Watcher
 | 
			
		||||
	open      []Address // all the addresses the client should potentially connect
 | 
			
		||||
	mu        sync.Mutex
 | 
			
		||||
	addrCh    chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to.
 | 
			
		||||
	connected []Address      // all the connected addresses
 | 
			
		||||
	next      int            // index of the next address to return for Get()
 | 
			
		||||
	waitCh    chan struct{}  // the channel to block when there is no connected address available
 | 
			
		||||
	done      bool           // The Balancer is closed.
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rr *roundRobin) watchAddrUpdates() error {
 | 
			
		||||
	updates, err := rr.w.Next()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		grpclog.Println("grpc: the naming watcher stops working due to %v.", err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	rr.mu.Lock()
 | 
			
		||||
	defer rr.mu.Unlock()
 | 
			
		||||
	for _, update := range updates {
 | 
			
		||||
		addr := Address{
 | 
			
		||||
			Addr: update.Addr,
 | 
			
		||||
		}
 | 
			
		||||
		switch update.Op {
 | 
			
		||||
		case naming.Add:
 | 
			
		||||
			var exist bool
 | 
			
		||||
			for _, v := range rr.open {
 | 
			
		||||
				if addr == v {
 | 
			
		||||
					exist = true
 | 
			
		||||
					grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if exist {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			rr.open = append(rr.open, addr)
 | 
			
		||||
		case naming.Delete:
 | 
			
		||||
			for i, v := range rr.open {
 | 
			
		||||
				if v == addr {
 | 
			
		||||
					copy(rr.open[i:], rr.open[i+1:])
 | 
			
		||||
					rr.open = rr.open[:len(rr.open)-1]
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			grpclog.Println("Unknown update.Op ", update.Op)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals gets notified.
 | 
			
		||||
	open := make([]Address, len(rr.open), len(rr.open))
 | 
			
		||||
	copy(open, rr.open)
 | 
			
		||||
	if rr.done {
 | 
			
		||||
		return ErrClientConnClosing
 | 
			
		||||
	}
 | 
			
		||||
	rr.addrCh <- open
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rr *roundRobin) Start(target string) error {
 | 
			
		||||
	if rr.r == nil {
 | 
			
		||||
		// If there is no name resolver installed, it is not needed to
 | 
			
		||||
		// do name resolution. In this case, rr.addrCh stays nil.
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	w, err := rr.r.Resolve(target)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	rr.w = w
 | 
			
		||||
	rr.addrCh = make(chan []Address)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			if err := rr.watchAddrUpdates(); err != nil {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Up appends addr to the end of rr.connected and sends notification if there
 | 
			
		||||
// are pending Get() calls.
 | 
			
		||||
func (rr *roundRobin) Up(addr Address) func(error) {
 | 
			
		||||
	rr.mu.Lock()
 | 
			
		||||
	defer rr.mu.Unlock()
 | 
			
		||||
	for _, a := range rr.connected {
 | 
			
		||||
		if a == addr {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	rr.connected = append(rr.connected, addr)
 | 
			
		||||
	if len(rr.connected) == 1 {
 | 
			
		||||
		// addr is only one available. Notify the Get() callers who are blocking.
 | 
			
		||||
		if rr.waitCh != nil {
 | 
			
		||||
			close(rr.waitCh)
 | 
			
		||||
			rr.waitCh = nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return func(err error) {
 | 
			
		||||
		rr.down(addr, err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// down removes addr from rr.connected and moves the remaining addrs forward.
 | 
			
		||||
func (rr *roundRobin) down(addr Address, err error) {
 | 
			
		||||
	rr.mu.Lock()
 | 
			
		||||
	defer rr.mu.Unlock()
 | 
			
		||||
	for i, a := range rr.connected {
 | 
			
		||||
		if a == addr {
 | 
			
		||||
			copy(rr.connected[i:], rr.connected[i+1:])
 | 
			
		||||
			rr.connected = rr.connected[:len(rr.connected)-1]
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get returns the next addr in the rotation.
 | 
			
		||||
func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
 | 
			
		||||
	var ch chan struct{}
 | 
			
		||||
	rr.mu.Lock()
 | 
			
		||||
	if rr.done {
 | 
			
		||||
		rr.mu.Unlock()
 | 
			
		||||
		err = ErrClientConnClosing
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if rr.next >= len(rr.connected) {
 | 
			
		||||
		rr.next = 0
 | 
			
		||||
	}
 | 
			
		||||
	if len(rr.connected) > 0 {
 | 
			
		||||
		addr = rr.connected[rr.next]
 | 
			
		||||
		rr.next++
 | 
			
		||||
		rr.mu.Unlock()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// There is no address available. Wait on rr.waitCh.
 | 
			
		||||
	// TODO(zhaoq): Handle the case when opts.BlockingWait is false.
 | 
			
		||||
	if rr.waitCh == nil {
 | 
			
		||||
		ch = make(chan struct{})
 | 
			
		||||
		rr.waitCh = ch
 | 
			
		||||
	} else {
 | 
			
		||||
		ch = rr.waitCh
 | 
			
		||||
	}
 | 
			
		||||
	rr.mu.Unlock()
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ctx.Done():
 | 
			
		||||
			err = transport.ContextErr(ctx.Err())
 | 
			
		||||
			return
 | 
			
		||||
		case <-ch:
 | 
			
		||||
			rr.mu.Lock()
 | 
			
		||||
			if rr.done {
 | 
			
		||||
				rr.mu.Unlock()
 | 
			
		||||
				err = ErrClientConnClosing
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if len(rr.connected) == 0 {
 | 
			
		||||
				// The newly added addr got removed by Down() again.
 | 
			
		||||
				if rr.waitCh == nil {
 | 
			
		||||
					ch = make(chan struct{})
 | 
			
		||||
					rr.waitCh = ch
 | 
			
		||||
				} else {
 | 
			
		||||
					ch = rr.waitCh
 | 
			
		||||
				}
 | 
			
		||||
				rr.mu.Unlock()
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if rr.next >= len(rr.connected) {
 | 
			
		||||
				rr.next = 0
 | 
			
		||||
			}
 | 
			
		||||
			addr = rr.connected[rr.next]
 | 
			
		||||
			rr.next++
 | 
			
		||||
			rr.mu.Unlock()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rr *roundRobin) Notify() <-chan []Address {
 | 
			
		||||
	return rr.addrCh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rr *roundRobin) Close() error {
 | 
			
		||||
	rr.mu.Lock()
 | 
			
		||||
	defer rr.mu.Unlock()
 | 
			
		||||
	rr.done = true
 | 
			
		||||
	if rr.w != nil {
 | 
			
		||||
		rr.w.Close()
 | 
			
		||||
	}
 | 
			
		||||
	if rr.waitCh != nil {
 | 
			
		||||
		close(rr.waitCh)
 | 
			
		||||
		rr.waitCh = nil
 | 
			
		||||
	}
 | 
			
		||||
	if rr.addrCh != nil {
 | 
			
		||||
		close(rr.addrCh)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -132,19 +132,16 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 | 
			
		|||
		Last:  true,
 | 
			
		||||
		Delay: false,
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
		lastErr error // record the error that happened
 | 
			
		||||
	)
 | 
			
		||||
	for {
 | 
			
		||||
		var (
 | 
			
		||||
			err    error
 | 
			
		||||
			t      transport.ClientTransport
 | 
			
		||||
			stream *transport.Stream
 | 
			
		||||
			// Record the put handler from Balancer.Get(...). It is called once the
 | 
			
		||||
			// RPC has completed or failed.
 | 
			
		||||
			put func()
 | 
			
		||||
		)
 | 
			
		||||
		// TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs.
 | 
			
		||||
		if lastErr != nil && c.failFast {
 | 
			
		||||
			return toRPCErr(lastErr)
 | 
			
		||||
		}
 | 
			
		||||
		// TODO(zhaoq): Need a formal spec of fail-fast.
 | 
			
		||||
		callHdr := &transport.CallHdr{
 | 
			
		||||
			Host:   cc.authority,
 | 
			
		||||
			Method: method,
 | 
			
		||||
| 
						 | 
				
			
			@ -152,39 +149,66 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 | 
			
		|||
		if cc.dopts.cp != nil {
 | 
			
		||||
			callHdr.SendCompress = cc.dopts.cp.Type()
 | 
			
		||||
		}
 | 
			
		||||
		t, err = cc.dopts.picker.Pick(ctx)
 | 
			
		||||
		gopts := BalancerGetOptions{
 | 
			
		||||
			BlockingWait: !c.failFast,
 | 
			
		||||
		}
 | 
			
		||||
		t, put, err = cc.getTransport(ctx, gopts)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if lastErr != nil {
 | 
			
		||||
				// This was a retry; return the error from the last attempt.
 | 
			
		||||
				return toRPCErr(lastErr)
 | 
			
		||||
			// TODO(zhaoq): Probably revisit the error handling.
 | 
			
		||||
			if err == ErrClientConnClosing {
 | 
			
		||||
				return Errorf(codes.FailedPrecondition, "%v", err)
 | 
			
		||||
			}
 | 
			
		||||
			return toRPCErr(err)
 | 
			
		||||
			if _, ok := err.(transport.StreamError); ok {
 | 
			
		||||
				return toRPCErr(err)
 | 
			
		||||
			}
 | 
			
		||||
			if _, ok := err.(transport.ConnectionError); ok {
 | 
			
		||||
				if c.failFast {
 | 
			
		||||
					return toRPCErr(err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			// All the remaining cases are treated as retryable.
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if c.traceInfo.tr != nil {
 | 
			
		||||
			c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
 | 
			
		||||
		}
 | 
			
		||||
		stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if _, ok := err.(transport.ConnectionError); ok {
 | 
			
		||||
				lastErr = err
 | 
			
		||||
				continue
 | 
			
		||||
			if put != nil {
 | 
			
		||||
				put()
 | 
			
		||||
				put = nil
 | 
			
		||||
			}
 | 
			
		||||
			if lastErr != nil {
 | 
			
		||||
				return toRPCErr(lastErr)
 | 
			
		||||
			if _, ok := err.(transport.ConnectionError); ok {
 | 
			
		||||
				if c.failFast {
 | 
			
		||||
					return toRPCErr(err)
 | 
			
		||||
				}
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			return toRPCErr(err)
 | 
			
		||||
		}
 | 
			
		||||
		// Receive the response
 | 
			
		||||
		lastErr = recvResponse(cc.dopts, t, &c, stream, reply)
 | 
			
		||||
		if _, ok := lastErr.(transport.ConnectionError); ok {
 | 
			
		||||
			continue
 | 
			
		||||
		err = recvResponse(cc.dopts, t, &c, stream, reply)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if put != nil {
 | 
			
		||||
				put()
 | 
			
		||||
				put = nil
 | 
			
		||||
			}
 | 
			
		||||
			if _, ok := err.(transport.ConnectionError); ok {
 | 
			
		||||
				if c.failFast {
 | 
			
		||||
					return toRPCErr(err)
 | 
			
		||||
				}
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			t.CloseStream(stream, err)
 | 
			
		||||
			return toRPCErr(err)
 | 
			
		||||
		}
 | 
			
		||||
		if c.traceInfo.tr != nil {
 | 
			
		||||
			c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
 | 
			
		||||
		}
 | 
			
		||||
		t.CloseStream(stream, lastErr)
 | 
			
		||||
		if lastErr != nil {
 | 
			
		||||
			return toRPCErr(lastErr)
 | 
			
		||||
		t.CloseStream(stream, nil)
 | 
			
		||||
		if put != nil {
 | 
			
		||||
			put()
 | 
			
		||||
			put = nil
 | 
			
		||||
		}
 | 
			
		||||
		return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,28 +43,38 @@ import (
 | 
			
		|||
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
	"golang.org/x/net/trace"
 | 
			
		||||
	"google.golang.org/grpc/codes"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
	"google.golang.org/grpc/grpclog"
 | 
			
		||||
	"google.golang.org/grpc/transport"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// ErrUnspecTarget indicates that the target address is unspecified.
 | 
			
		||||
	ErrUnspecTarget = errors.New("grpc: target is unspecified")
 | 
			
		||||
	// ErrNoTransportSecurity indicates that there is no transport security
 | 
			
		||||
	// ErrClientConnClosing indicates that the operation is illegal because
 | 
			
		||||
	// the ClientConn is closing.
 | 
			
		||||
	ErrClientConnClosing = errors.New("grpc: the client connection is closing")
 | 
			
		||||
	// ErrClientConnTimeout indicates that the ClientConn cannot establish the
 | 
			
		||||
	// underlying connections within the specified timeout.
 | 
			
		||||
	ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
 | 
			
		||||
 | 
			
		||||
	// errNoTransportSecurity indicates that there is no transport security
 | 
			
		||||
	// being set for ClientConn. Users should either set one or explicitly
 | 
			
		||||
	// call WithInsecure DialOption to disable security.
 | 
			
		||||
	ErrNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
 | 
			
		||||
	// ErrCredentialsMisuse indicates that users want to transmit security information
 | 
			
		||||
	// (e.g., oauth2 token) which requires secure connection on an insecure
 | 
			
		||||
	errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
 | 
			
		||||
	// errTransportCredentialsMissing indicates that users want to transmit security
 | 
			
		||||
	// information (e.g., oauth2 token) which requires secure connection on an insecure
 | 
			
		||||
	// connection.
 | 
			
		||||
	ErrCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)")
 | 
			
		||||
	// ErrClientConnClosing indicates that the operation is illegal because
 | 
			
		||||
	// the session is closing.
 | 
			
		||||
	ErrClientConnClosing = errors.New("grpc: the client connection is closing")
 | 
			
		||||
	// ErrClientConnTimeout indicates that the connection could not be
 | 
			
		||||
	// established or re-established within the specified timeout.
 | 
			
		||||
	ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
 | 
			
		||||
	errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)")
 | 
			
		||||
	// errCredentialsConflict indicates that grpc.WithTransportCredentials()
 | 
			
		||||
	// and grpc.WithInsecure() are both called for a connection.
 | 
			
		||||
	errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
 | 
			
		||||
	// errNetworkIP indicates that the connection is down due to some network I/O error.
 | 
			
		||||
	errNetworkIO = errors.New("grpc: failed with network I/O error")
 | 
			
		||||
	// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
 | 
			
		||||
	errConnDrain = errors.New("grpc: the connection is drained")
 | 
			
		||||
	// errConnClosing indicates that the connection is closing.
 | 
			
		||||
	errConnClosing = errors.New("grpc: the connection is closing")
 | 
			
		||||
	errNoAddr      = errors.New("grpc: there is no address available to dial")
 | 
			
		||||
	// minimum time to give a connection to complete
 | 
			
		||||
	minConnectTimeout = 20 * time.Second
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -76,9 +86,10 @@ type dialOptions struct {
 | 
			
		|||
	cp       Compressor
 | 
			
		||||
	dc       Decompressor
 | 
			
		||||
	bs       backoffStrategy
 | 
			
		||||
	picker   Picker
 | 
			
		||||
	balancer Balancer
 | 
			
		||||
	block    bool
 | 
			
		||||
	insecure bool
 | 
			
		||||
	timeout  time.Duration
 | 
			
		||||
	copts    transport.ConnectOptions
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -108,10 +119,10 @@ func WithDecompressor(dc Decompressor) DialOption {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithPicker returns a DialOption which sets a picker for connection selection.
 | 
			
		||||
func WithPicker(p Picker) DialOption {
 | 
			
		||||
// WithBalancer returns a DialOption which sets a load balancer.
 | 
			
		||||
func WithBalancer(b Balancer) DialOption {
 | 
			
		||||
	return func(o *dialOptions) {
 | 
			
		||||
		o.picker = p
 | 
			
		||||
		o.balancer = b
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -136,7 +147,7 @@ func WithBackoffConfig(b BackoffConfig) DialOption {
 | 
			
		|||
// withBackoff sets the backoff strategy used for retries after a
 | 
			
		||||
// failed connection attempt.
 | 
			
		||||
//
 | 
			
		||||
// This can be exported if arbitrary backoff strategies are allowed by GRPC.
 | 
			
		||||
// This can be exported if arbitrary backoff strategies are allowed by gRPC.
 | 
			
		||||
func withBackoff(bs backoffStrategy) DialOption {
 | 
			
		||||
	return func(o *dialOptions) {
 | 
			
		||||
		o.bs = bs
 | 
			
		||||
| 
						 | 
				
			
			@ -162,24 +173,25 @@ func WithInsecure() DialOption {
 | 
			
		|||
 | 
			
		||||
// WithTransportCredentials returns a DialOption which configures a
 | 
			
		||||
// connection level security credentials (e.g., TLS/SSL).
 | 
			
		||||
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
 | 
			
		||||
func WithTransportCredentials(creds credentials.TransportCredentials) DialOption {
 | 
			
		||||
	return func(o *dialOptions) {
 | 
			
		||||
		o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
 | 
			
		||||
		o.copts.TransportCredentials = creds
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithPerRPCCredentials returns a DialOption which sets
 | 
			
		||||
// credentials which will place auth state on each outbound RPC.
 | 
			
		||||
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
 | 
			
		||||
func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
 | 
			
		||||
	return func(o *dialOptions) {
 | 
			
		||||
		o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
 | 
			
		||||
		o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTimeout returns a DialOption that configures a timeout for dialing a client connection.
 | 
			
		||||
// WithTimeout returns a DialOption that configures a timeout for dialing a ClientConn
 | 
			
		||||
// initially. This is valid if and only if WithBlock() is present.
 | 
			
		||||
func WithTimeout(d time.Duration) DialOption {
 | 
			
		||||
	return func(o *dialOptions) {
 | 
			
		||||
		o.copts.Timeout = d
 | 
			
		||||
		o.timeout = d
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -201,6 +213,7 @@ func WithUserAgent(s string) DialOption {
 | 
			
		|||
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
 | 
			
		||||
	cc := &ClientConn{
 | 
			
		||||
		target: target,
 | 
			
		||||
		conns:  make(map[Address]*addrConn),
 | 
			
		||||
	}
 | 
			
		||||
	for _, opt := range opts {
 | 
			
		||||
		opt(&cc.dopts)
 | 
			
		||||
| 
						 | 
				
			
			@ -214,13 +227,53 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
 | 
			
		|||
		cc.dopts.bs = DefaultBackoffConfig
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if cc.dopts.picker == nil {
 | 
			
		||||
		cc.dopts.picker = &unicastPicker{
 | 
			
		||||
			target: target,
 | 
			
		||||
	cc.balancer = cc.dopts.balancer
 | 
			
		||||
	if cc.balancer == nil {
 | 
			
		||||
		cc.balancer = RoundRobin(nil)
 | 
			
		||||
	}
 | 
			
		||||
	if err := cc.balancer.Start(target); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	var (
 | 
			
		||||
		ok    bool
 | 
			
		||||
		addrs []Address
 | 
			
		||||
	)
 | 
			
		||||
	ch := cc.balancer.Notify()
 | 
			
		||||
	if ch == nil {
 | 
			
		||||
		// There is no name resolver installed.
 | 
			
		||||
		addrs = append(addrs, Address{Addr: target})
 | 
			
		||||
	} else {
 | 
			
		||||
		addrs, ok = <-ch
 | 
			
		||||
		if !ok || len(addrs) == 0 {
 | 
			
		||||
			return nil, errNoAddr
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err := cc.dopts.picker.Init(cc); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	waitC := make(chan error, 1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for _, a := range addrs {
 | 
			
		||||
			if err := cc.newAddrConn(a, false); err != nil {
 | 
			
		||||
				waitC <- err
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		close(waitC)
 | 
			
		||||
	}()
 | 
			
		||||
	var timeoutCh <-chan time.Time
 | 
			
		||||
	if cc.dopts.timeout > 0 {
 | 
			
		||||
		timeoutCh = time.After(cc.dopts.timeout)
 | 
			
		||||
	}
 | 
			
		||||
	select {
 | 
			
		||||
	case err := <-waitC:
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			cc.Close()
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	case <-timeoutCh:
 | 
			
		||||
		cc.Close()
 | 
			
		||||
		return nil, ErrClientConnTimeout
 | 
			
		||||
	}
 | 
			
		||||
	if ok {
 | 
			
		||||
		go cc.lbWatcher()
 | 
			
		||||
	}
 | 
			
		||||
	colonPos := strings.LastIndex(target, ":")
 | 
			
		||||
	if colonPos == -1 {
 | 
			
		||||
| 
						 | 
				
			
			@ -263,325 +316,358 @@ func (s ConnectivityState) String() string {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ClientConn represents a client connection to an RPC service.
 | 
			
		||||
// ClientConn represents a client connection to an RPC server.
 | 
			
		||||
type ClientConn struct {
 | 
			
		||||
	target    string
 | 
			
		||||
	balancer  Balancer
 | 
			
		||||
	authority string
 | 
			
		||||
	dopts     dialOptions
 | 
			
		||||
 | 
			
		||||
	mu    sync.RWMutex
 | 
			
		||||
	conns map[Address]*addrConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// State returns the connectivity state of cc.
 | 
			
		||||
// This is EXPERIMENTAL API.
 | 
			
		||||
func (cc *ClientConn) State() (ConnectivityState, error) {
 | 
			
		||||
	return cc.dopts.picker.State()
 | 
			
		||||
func (cc *ClientConn) lbWatcher() {
 | 
			
		||||
	for addrs := range cc.balancer.Notify() {
 | 
			
		||||
		var (
 | 
			
		||||
			add []Address   // Addresses need to setup connections.
 | 
			
		||||
			del []*addrConn // Connections need to tear down.
 | 
			
		||||
		)
 | 
			
		||||
		cc.mu.Lock()
 | 
			
		||||
		for _, a := range addrs {
 | 
			
		||||
			if _, ok := cc.conns[a]; !ok {
 | 
			
		||||
				add = append(add, a)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		for k, c := range cc.conns {
 | 
			
		||||
			var keep bool
 | 
			
		||||
			for _, a := range addrs {
 | 
			
		||||
				if k == a {
 | 
			
		||||
					keep = true
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if !keep {
 | 
			
		||||
				del = append(del, c)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		cc.mu.Unlock()
 | 
			
		||||
		for _, a := range add {
 | 
			
		||||
			cc.newAddrConn(a, true)
 | 
			
		||||
		}
 | 
			
		||||
		for _, c := range del {
 | 
			
		||||
			c.tearDown(errConnDrain)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WaitForStateChange blocks until the state changes to something other than the sourceState.
 | 
			
		||||
// It returns the new state or error.
 | 
			
		||||
// This is EXPERIMENTAL API.
 | 
			
		||||
func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
 | 
			
		||||
	return cc.dopts.picker.WaitForStateChange(ctx, sourceState)
 | 
			
		||||
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
 | 
			
		||||
	ac := &addrConn{
 | 
			
		||||
		cc:           cc,
 | 
			
		||||
		addr:         addr,
 | 
			
		||||
		dopts:        cc.dopts,
 | 
			
		||||
		shutdownChan: make(chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
	if EnableTracing {
 | 
			
		||||
		ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
 | 
			
		||||
	}
 | 
			
		||||
	if !ac.dopts.insecure {
 | 
			
		||||
		if ac.dopts.copts.TransportCredentials == nil {
 | 
			
		||||
			return errNoTransportSecurity
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		if ac.dopts.copts.TransportCredentials != nil {
 | 
			
		||||
			return errCredentialsConflict
 | 
			
		||||
		}
 | 
			
		||||
		for _, cd := range ac.dopts.copts.PerRPCCredentials {
 | 
			
		||||
			if cd.RequireTransportSecurity() {
 | 
			
		||||
				return errTransportCredentialsMissing
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
 | 
			
		||||
	ac.cc.mu.Lock()
 | 
			
		||||
	if ac.cc.conns == nil {
 | 
			
		||||
		ac.cc.mu.Unlock()
 | 
			
		||||
		return ErrClientConnClosing
 | 
			
		||||
	}
 | 
			
		||||
	stale := ac.cc.conns[ac.addr]
 | 
			
		||||
	ac.cc.conns[ac.addr] = ac
 | 
			
		||||
	ac.cc.mu.Unlock()
 | 
			
		||||
	if stale != nil {
 | 
			
		||||
		// There is an addrConn alive on ac.addr already. This could be due to
 | 
			
		||||
		// i) stale's Close is undergoing;
 | 
			
		||||
		// ii) a buggy Balancer notifies duplicated Addresses.
 | 
			
		||||
		stale.tearDown(errConnDrain)
 | 
			
		||||
	}
 | 
			
		||||
	ac.stateCV = sync.NewCond(&ac.mu)
 | 
			
		||||
	// skipWait may overwrite the decision in ac.dopts.block.
 | 
			
		||||
	if ac.dopts.block && !skipWait {
 | 
			
		||||
		if err := ac.resetTransport(false); err != nil {
 | 
			
		||||
			ac.tearDown(err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		// Start to monitor the error status of transport.
 | 
			
		||||
		go ac.transportMonitor()
 | 
			
		||||
	} else {
 | 
			
		||||
		// Start a goroutine connecting to the server asynchronously.
 | 
			
		||||
		go func() {
 | 
			
		||||
			if err := ac.resetTransport(false); err != nil {
 | 
			
		||||
				grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
 | 
			
		||||
				ac.tearDown(err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			ac.transportMonitor()
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close starts to tear down the ClientConn.
 | 
			
		||||
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
 | 
			
		||||
	// TODO(zhaoq): Implement fail-fast logic.
 | 
			
		||||
	addr, put, err := cc.balancer.Get(ctx, opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	cc.mu.RLock()
 | 
			
		||||
	if cc.conns == nil {
 | 
			
		||||
		cc.mu.RUnlock()
 | 
			
		||||
		return nil, nil, ErrClientConnClosing
 | 
			
		||||
	}
 | 
			
		||||
	ac, ok := cc.conns[addr]
 | 
			
		||||
	cc.mu.RUnlock()
 | 
			
		||||
	if !ok {
 | 
			
		||||
		if put != nil {
 | 
			
		||||
			put()
 | 
			
		||||
		}
 | 
			
		||||
		return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
 | 
			
		||||
	}
 | 
			
		||||
	t, err := ac.wait(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if put != nil {
 | 
			
		||||
			put()
 | 
			
		||||
		}
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return t, put, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close tears down the ClientConn and all underlying connections.
 | 
			
		||||
func (cc *ClientConn) Close() error {
 | 
			
		||||
	return cc.dopts.picker.Close()
 | 
			
		||||
	cc.mu.Lock()
 | 
			
		||||
	if cc.conns == nil {
 | 
			
		||||
		cc.mu.Unlock()
 | 
			
		||||
		return ErrClientConnClosing
 | 
			
		||||
	}
 | 
			
		||||
	conns := cc.conns
 | 
			
		||||
	cc.conns = nil
 | 
			
		||||
	cc.mu.Unlock()
 | 
			
		||||
	cc.balancer.Close()
 | 
			
		||||
	for _, ac := range conns {
 | 
			
		||||
		ac.tearDown(ErrClientConnClosing)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Conn is a client connection to a single destination.
 | 
			
		||||
type Conn struct {
 | 
			
		||||
	target       string
 | 
			
		||||
// addrConn is a network connection to a given address.
 | 
			
		||||
type addrConn struct {
 | 
			
		||||
	cc           *ClientConn
 | 
			
		||||
	addr         Address
 | 
			
		||||
	dopts        dialOptions
 | 
			
		||||
	resetChan    chan int
 | 
			
		||||
	shutdownChan chan struct{}
 | 
			
		||||
	events       trace.EventLog
 | 
			
		||||
 | 
			
		||||
	mu      sync.Mutex
 | 
			
		||||
	state   ConnectivityState
 | 
			
		||||
	stateCV *sync.Cond
 | 
			
		||||
	down    func(error) // the handler called when a connection is down.
 | 
			
		||||
	// ready is closed and becomes nil when a new transport is up or failed
 | 
			
		||||
	// due to timeout.
 | 
			
		||||
	ready     chan struct{}
 | 
			
		||||
	transport transport.ClientTransport
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewConn creates a Conn.
 | 
			
		||||
func NewConn(cc *ClientConn) (*Conn, error) {
 | 
			
		||||
	if cc.target == "" {
 | 
			
		||||
		return nil, ErrUnspecTarget
 | 
			
		||||
	}
 | 
			
		||||
	c := &Conn{
 | 
			
		||||
		target:       cc.target,
 | 
			
		||||
		dopts:        cc.dopts,
 | 
			
		||||
		resetChan:    make(chan int, 1),
 | 
			
		||||
		shutdownChan: make(chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
	if EnableTracing {
 | 
			
		||||
		c.events = trace.NewEventLog("grpc.ClientConn", c.target)
 | 
			
		||||
	}
 | 
			
		||||
	if !c.dopts.insecure {
 | 
			
		||||
		var ok bool
 | 
			
		||||
		for _, cd := range c.dopts.copts.AuthOptions {
 | 
			
		||||
			if _, ok = cd.(credentials.TransportAuthenticator); ok {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if !ok {
 | 
			
		||||
			return nil, ErrNoTransportSecurity
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		for _, cd := range c.dopts.copts.AuthOptions {
 | 
			
		||||
			if cd.RequireTransportSecurity() {
 | 
			
		||||
				return nil, ErrCredentialsMisuse
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	c.stateCV = sync.NewCond(&c.mu)
 | 
			
		||||
	if c.dopts.block {
 | 
			
		||||
		if err := c.resetTransport(false); err != nil {
 | 
			
		||||
			c.Close()
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		// Start to monitor the error status of transport.
 | 
			
		||||
		go c.transportMonitor()
 | 
			
		||||
	} else {
 | 
			
		||||
		// Start a goroutine connecting to the server asynchronously.
 | 
			
		||||
		go func() {
 | 
			
		||||
			if err := c.resetTransport(false); err != nil {
 | 
			
		||||
				grpclog.Printf("Failed to dial %s: %v; please retry.", c.target, err)
 | 
			
		||||
				c.Close()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			c.transportMonitor()
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// printf records an event in cc's event log, unless cc has been closed.
 | 
			
		||||
// REQUIRES cc.mu is held.
 | 
			
		||||
func (cc *Conn) printf(format string, a ...interface{}) {
 | 
			
		||||
	if cc.events != nil {
 | 
			
		||||
		cc.events.Printf(format, a...)
 | 
			
		||||
// printf records an event in ac's event log, unless ac has been closed.
 | 
			
		||||
// REQUIRES ac.mu is held.
 | 
			
		||||
func (ac *addrConn) printf(format string, a ...interface{}) {
 | 
			
		||||
	if ac.events != nil {
 | 
			
		||||
		ac.events.Printf(format, a...)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// errorf records an error in cc's event log, unless cc has been closed.
 | 
			
		||||
// REQUIRES cc.mu is held.
 | 
			
		||||
func (cc *Conn) errorf(format string, a ...interface{}) {
 | 
			
		||||
	if cc.events != nil {
 | 
			
		||||
		cc.events.Errorf(format, a...)
 | 
			
		||||
// errorf records an error in ac's event log, unless ac has been closed.
 | 
			
		||||
// REQUIRES ac.mu is held.
 | 
			
		||||
func (ac *addrConn) errorf(format string, a ...interface{}) {
 | 
			
		||||
	if ac.events != nil {
 | 
			
		||||
		ac.events.Errorf(format, a...)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// State returns the connectivity state of the Conn
 | 
			
		||||
func (cc *Conn) State() ConnectivityState {
 | 
			
		||||
	cc.mu.Lock()
 | 
			
		||||
	defer cc.mu.Unlock()
 | 
			
		||||
	return cc.state
 | 
			
		||||
// getState returns the connectivity state of the Conn
 | 
			
		||||
func (ac *addrConn) getState() ConnectivityState {
 | 
			
		||||
	ac.mu.Lock()
 | 
			
		||||
	defer ac.mu.Unlock()
 | 
			
		||||
	return ac.state
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WaitForStateChange blocks until the state changes to something other than the sourceState.
 | 
			
		||||
func (cc *Conn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
 | 
			
		||||
	cc.mu.Lock()
 | 
			
		||||
	defer cc.mu.Unlock()
 | 
			
		||||
	if sourceState != cc.state {
 | 
			
		||||
		return cc.state, nil
 | 
			
		||||
// waitForStateChange blocks until the state changes to something other than the sourceState.
 | 
			
		||||
func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
 | 
			
		||||
	ac.mu.Lock()
 | 
			
		||||
	defer ac.mu.Unlock()
 | 
			
		||||
	if sourceState != ac.state {
 | 
			
		||||
		return ac.state, nil
 | 
			
		||||
	}
 | 
			
		||||
	done := make(chan struct{})
 | 
			
		||||
	var err error
 | 
			
		||||
	go func() {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ctx.Done():
 | 
			
		||||
			cc.mu.Lock()
 | 
			
		||||
			ac.mu.Lock()
 | 
			
		||||
			err = ctx.Err()
 | 
			
		||||
			cc.stateCV.Broadcast()
 | 
			
		||||
			cc.mu.Unlock()
 | 
			
		||||
			ac.stateCV.Broadcast()
 | 
			
		||||
			ac.mu.Unlock()
 | 
			
		||||
		case <-done:
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	defer close(done)
 | 
			
		||||
	for sourceState == cc.state {
 | 
			
		||||
		cc.stateCV.Wait()
 | 
			
		||||
	for sourceState == ac.state {
 | 
			
		||||
		ac.stateCV.Wait()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return cc.state, err
 | 
			
		||||
			return ac.state, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return cc.state, nil
 | 
			
		||||
	return ac.state, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NotifyReset tries to signal the underlying transport needs to be reset due to
 | 
			
		||||
// for example a name resolution change in flight.
 | 
			
		||||
func (cc *Conn) NotifyReset() {
 | 
			
		||||
	select {
 | 
			
		||||
	case cc.resetChan <- 0:
 | 
			
		||||
	default:
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cc *Conn) resetTransport(closeTransport bool) error {
 | 
			
		||||
func (ac *addrConn) resetTransport(closeTransport bool) error {
 | 
			
		||||
	var retries int
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	for {
 | 
			
		||||
		cc.mu.Lock()
 | 
			
		||||
		cc.printf("connecting")
 | 
			
		||||
		if cc.state == Shutdown {
 | 
			
		||||
			// cc.Close() has been invoked.
 | 
			
		||||
			cc.mu.Unlock()
 | 
			
		||||
			return ErrClientConnClosing
 | 
			
		||||
		ac.mu.Lock()
 | 
			
		||||
		ac.printf("connecting")
 | 
			
		||||
		if ac.state == Shutdown {
 | 
			
		||||
			// ac.tearDown(...) has been invoked.
 | 
			
		||||
			ac.mu.Unlock()
 | 
			
		||||
			return errConnClosing
 | 
			
		||||
		}
 | 
			
		||||
		cc.state = Connecting
 | 
			
		||||
		cc.stateCV.Broadcast()
 | 
			
		||||
		cc.mu.Unlock()
 | 
			
		||||
		if closeTransport {
 | 
			
		||||
			cc.transport.Close()
 | 
			
		||||
		if ac.down != nil {
 | 
			
		||||
			ac.down(downErrorf(false, true, "%v", errNetworkIO))
 | 
			
		||||
			ac.down = nil
 | 
			
		||||
		}
 | 
			
		||||
		// Adjust timeout for the current try.
 | 
			
		||||
		copts := cc.dopts.copts
 | 
			
		||||
		if copts.Timeout < 0 {
 | 
			
		||||
			cc.Close()
 | 
			
		||||
			return ErrClientConnTimeout
 | 
			
		||||
		ac.state = Connecting
 | 
			
		||||
		ac.stateCV.Broadcast()
 | 
			
		||||
		t := ac.transport
 | 
			
		||||
		ac.mu.Unlock()
 | 
			
		||||
		if closeTransport && t != nil {
 | 
			
		||||
			t.Close()
 | 
			
		||||
		}
 | 
			
		||||
		if copts.Timeout > 0 {
 | 
			
		||||
			copts.Timeout -= time.Since(start)
 | 
			
		||||
			if copts.Timeout <= 0 {
 | 
			
		||||
				cc.Close()
 | 
			
		||||
				return ErrClientConnTimeout
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		sleepTime := cc.dopts.bs.backoff(retries)
 | 
			
		||||
		timeout := sleepTime
 | 
			
		||||
		if timeout < minConnectTimeout {
 | 
			
		||||
			timeout = minConnectTimeout
 | 
			
		||||
		}
 | 
			
		||||
		if copts.Timeout == 0 || copts.Timeout > timeout {
 | 
			
		||||
			copts.Timeout = timeout
 | 
			
		||||
		sleepTime := ac.dopts.bs.backoff(retries)
 | 
			
		||||
		ac.dopts.copts.Timeout = sleepTime
 | 
			
		||||
		if sleepTime < minConnectTimeout {
 | 
			
		||||
			ac.dopts.copts.Timeout = minConnectTimeout
 | 
			
		||||
		}
 | 
			
		||||
		connectTime := time.Now()
 | 
			
		||||
		addr, err := cc.dopts.picker.PickAddr()
 | 
			
		||||
		var newTransport transport.ClientTransport
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			newTransport, err = transport.NewClientTransport(addr, &copts)
 | 
			
		||||
		}
 | 
			
		||||
		newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			cc.mu.Lock()
 | 
			
		||||
			if cc.state == Shutdown {
 | 
			
		||||
				// cc.Close() has been invoked.
 | 
			
		||||
				cc.mu.Unlock()
 | 
			
		||||
				return ErrClientConnClosing
 | 
			
		||||
			ac.mu.Lock()
 | 
			
		||||
			if ac.state == Shutdown {
 | 
			
		||||
				// ac.tearDown(...) has been invoked.
 | 
			
		||||
				ac.mu.Unlock()
 | 
			
		||||
				return errConnClosing
 | 
			
		||||
			}
 | 
			
		||||
			cc.errorf("transient failure: %v", err)
 | 
			
		||||
			cc.state = TransientFailure
 | 
			
		||||
			cc.stateCV.Broadcast()
 | 
			
		||||
			if cc.ready != nil {
 | 
			
		||||
				close(cc.ready)
 | 
			
		||||
				cc.ready = nil
 | 
			
		||||
			ac.errorf("transient failure: %v", err)
 | 
			
		||||
			ac.state = TransientFailure
 | 
			
		||||
			ac.stateCV.Broadcast()
 | 
			
		||||
			if ac.ready != nil {
 | 
			
		||||
				close(ac.ready)
 | 
			
		||||
				ac.ready = nil
 | 
			
		||||
			}
 | 
			
		||||
			cc.mu.Unlock()
 | 
			
		||||
			ac.mu.Unlock()
 | 
			
		||||
			sleepTime -= time.Since(connectTime)
 | 
			
		||||
			if sleepTime < 0 {
 | 
			
		||||
				sleepTime = 0
 | 
			
		||||
			}
 | 
			
		||||
			// Fail early before falling into sleep.
 | 
			
		||||
			if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) {
 | 
			
		||||
				cc.mu.Lock()
 | 
			
		||||
				cc.errorf("connection timeout")
 | 
			
		||||
				cc.mu.Unlock()
 | 
			
		||||
				cc.Close()
 | 
			
		||||
				return ErrClientConnTimeout
 | 
			
		||||
			}
 | 
			
		||||
			closeTransport = false
 | 
			
		||||
			time.Sleep(sleepTime)
 | 
			
		||||
			select {
 | 
			
		||||
			case <-time.After(sleepTime):
 | 
			
		||||
			case <-ac.shutdownChan:
 | 
			
		||||
			}
 | 
			
		||||
			retries++
 | 
			
		||||
			grpclog.Printf("grpc: Conn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target)
 | 
			
		||||
			grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		cc.mu.Lock()
 | 
			
		||||
		cc.printf("ready")
 | 
			
		||||
		if cc.state == Shutdown {
 | 
			
		||||
			// cc.Close() has been invoked.
 | 
			
		||||
			cc.mu.Unlock()
 | 
			
		||||
		ac.mu.Lock()
 | 
			
		||||
		ac.printf("ready")
 | 
			
		||||
		if ac.state == Shutdown {
 | 
			
		||||
			// ac.tearDown(...) has been invoked.
 | 
			
		||||
			ac.mu.Unlock()
 | 
			
		||||
			newTransport.Close()
 | 
			
		||||
			return ErrClientConnClosing
 | 
			
		||||
			return errConnClosing
 | 
			
		||||
		}
 | 
			
		||||
		cc.state = Ready
 | 
			
		||||
		cc.stateCV.Broadcast()
 | 
			
		||||
		cc.transport = newTransport
 | 
			
		||||
		if cc.ready != nil {
 | 
			
		||||
			close(cc.ready)
 | 
			
		||||
			cc.ready = nil
 | 
			
		||||
		ac.state = Ready
 | 
			
		||||
		ac.stateCV.Broadcast()
 | 
			
		||||
		ac.transport = newTransport
 | 
			
		||||
		if ac.ready != nil {
 | 
			
		||||
			close(ac.ready)
 | 
			
		||||
			ac.ready = nil
 | 
			
		||||
		}
 | 
			
		||||
		cc.mu.Unlock()
 | 
			
		||||
		ac.down = ac.cc.balancer.Up(ac.addr)
 | 
			
		||||
		ac.mu.Unlock()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cc *Conn) reconnect() bool {
 | 
			
		||||
	cc.mu.Lock()
 | 
			
		||||
	if cc.state == Shutdown {
 | 
			
		||||
		// cc.Close() has been invoked.
 | 
			
		||||
		cc.mu.Unlock()
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	cc.state = TransientFailure
 | 
			
		||||
	cc.stateCV.Broadcast()
 | 
			
		||||
	cc.mu.Unlock()
 | 
			
		||||
	if err := cc.resetTransport(true); err != nil {
 | 
			
		||||
		// The ClientConn is closing.
 | 
			
		||||
		cc.mu.Lock()
 | 
			
		||||
		cc.printf("transport exiting: %v", err)
 | 
			
		||||
		cc.mu.Unlock()
 | 
			
		||||
		grpclog.Printf("grpc: Conn.transportMonitor exits due to: %v", err)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Run in a goroutine to track the error in transport and create the
 | 
			
		||||
// new transport if an error happens. It returns when the channel is closing.
 | 
			
		||||
func (cc *Conn) transportMonitor() {
 | 
			
		||||
func (ac *addrConn) transportMonitor() {
 | 
			
		||||
	for {
 | 
			
		||||
		ac.mu.Lock()
 | 
			
		||||
		t := ac.transport
 | 
			
		||||
		ac.mu.Unlock()
 | 
			
		||||
		select {
 | 
			
		||||
		// shutdownChan is needed to detect the teardown when
 | 
			
		||||
		// the ClientConn is idle (i.e., no RPC in flight).
 | 
			
		||||
		case <-cc.shutdownChan:
 | 
			
		||||
		// the addrConn is idle (i.e., no RPC in flight).
 | 
			
		||||
		case <-ac.shutdownChan:
 | 
			
		||||
			return
 | 
			
		||||
		case <-cc.resetChan:
 | 
			
		||||
			if !cc.reconnect() {
 | 
			
		||||
		case <-t.Error():
 | 
			
		||||
			ac.mu.Lock()
 | 
			
		||||
			if ac.state == Shutdown {
 | 
			
		||||
				// ac.tearDown(...) has been invoked.
 | 
			
		||||
				ac.mu.Unlock()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		case <-cc.transport.Error():
 | 
			
		||||
			if !cc.reconnect() {
 | 
			
		||||
			ac.state = TransientFailure
 | 
			
		||||
			ac.stateCV.Broadcast()
 | 
			
		||||
			ac.mu.Unlock()
 | 
			
		||||
			if err := ac.resetTransport(true); err != nil {
 | 
			
		||||
				ac.mu.Lock()
 | 
			
		||||
				ac.printf("transport exiting: %v", err)
 | 
			
		||||
				ac.mu.Unlock()
 | 
			
		||||
				grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			// Tries to drain reset signal if there is any since it is out-dated.
 | 
			
		||||
			select {
 | 
			
		||||
			case <-cc.resetChan:
 | 
			
		||||
			default:
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Wait blocks until i) the new transport is up or ii) ctx is done or iii) cc is closed.
 | 
			
		||||
func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) {
 | 
			
		||||
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed.
 | 
			
		||||
func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) {
 | 
			
		||||
	for {
 | 
			
		||||
		cc.mu.Lock()
 | 
			
		||||
		ac.mu.Lock()
 | 
			
		||||
		switch {
 | 
			
		||||
		case cc.state == Shutdown:
 | 
			
		||||
			cc.mu.Unlock()
 | 
			
		||||
			return nil, ErrClientConnClosing
 | 
			
		||||
		case cc.state == Ready:
 | 
			
		||||
			ct := cc.transport
 | 
			
		||||
			cc.mu.Unlock()
 | 
			
		||||
		case ac.state == Shutdown:
 | 
			
		||||
			ac.mu.Unlock()
 | 
			
		||||
			return nil, errConnClosing
 | 
			
		||||
		case ac.state == Ready:
 | 
			
		||||
			ct := ac.transport
 | 
			
		||||
			ac.mu.Unlock()
 | 
			
		||||
			return ct, nil
 | 
			
		||||
		default:
 | 
			
		||||
			ready := cc.ready
 | 
			
		||||
			ready := ac.ready
 | 
			
		||||
			if ready == nil {
 | 
			
		||||
				ready = make(chan struct{})
 | 
			
		||||
				cc.ready = ready
 | 
			
		||||
				ac.ready = ready
 | 
			
		||||
			}
 | 
			
		||||
			cc.mu.Unlock()
 | 
			
		||||
			ac.mu.Unlock()
 | 
			
		||||
			select {
 | 
			
		||||
			case <-ctx.Done():
 | 
			
		||||
				return nil, transport.ContextErr(ctx.Err())
 | 
			
		||||
| 
						 | 
				
			
			@ -592,32 +678,46 @@ func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close starts to tear down the Conn. Returns ErrClientConnClosing if
 | 
			
		||||
// it has been closed (mostly due to dial time-out).
 | 
			
		||||
// tearDown starts to tear down the addrConn.
 | 
			
		||||
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
 | 
			
		||||
// some edge cases (e.g., the caller opens and closes many ClientConn's in a
 | 
			
		||||
// some edge cases (e.g., the caller opens and closes many addrConn's in a
 | 
			
		||||
// tight loop.
 | 
			
		||||
func (cc *Conn) Close() error {
 | 
			
		||||
	cc.mu.Lock()
 | 
			
		||||
	defer cc.mu.Unlock()
 | 
			
		||||
	if cc.state == Shutdown {
 | 
			
		||||
		return ErrClientConnClosing
 | 
			
		||||
func (ac *addrConn) tearDown(err error) {
 | 
			
		||||
	ac.mu.Lock()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		ac.mu.Unlock()
 | 
			
		||||
		ac.cc.mu.Lock()
 | 
			
		||||
		if ac.cc.conns != nil {
 | 
			
		||||
			delete(ac.cc.conns, ac.addr)
 | 
			
		||||
		}
 | 
			
		||||
		ac.cc.mu.Unlock()
 | 
			
		||||
	}()
 | 
			
		||||
	if ac.state == Shutdown {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	cc.state = Shutdown
 | 
			
		||||
	cc.stateCV.Broadcast()
 | 
			
		||||
	if cc.events != nil {
 | 
			
		||||
		cc.events.Finish()
 | 
			
		||||
		cc.events = nil
 | 
			
		||||
	ac.state = Shutdown
 | 
			
		||||
	if ac.down != nil {
 | 
			
		||||
		ac.down(downErrorf(false, false, "%v", err))
 | 
			
		||||
		ac.down = nil
 | 
			
		||||
	}
 | 
			
		||||
	if cc.ready != nil {
 | 
			
		||||
		close(cc.ready)
 | 
			
		||||
		cc.ready = nil
 | 
			
		||||
	ac.stateCV.Broadcast()
 | 
			
		||||
	if ac.events != nil {
 | 
			
		||||
		ac.events.Finish()
 | 
			
		||||
		ac.events = nil
 | 
			
		||||
	}
 | 
			
		||||
	if cc.transport != nil {
 | 
			
		||||
		cc.transport.Close()
 | 
			
		||||
	if ac.ready != nil {
 | 
			
		||||
		close(ac.ready)
 | 
			
		||||
		ac.ready = nil
 | 
			
		||||
	}
 | 
			
		||||
	if cc.shutdownChan != nil {
 | 
			
		||||
		close(cc.shutdownChan)
 | 
			
		||||
	if ac.transport != nil {
 | 
			
		||||
		if err == errConnDrain {
 | 
			
		||||
			ac.transport.GracefulClose()
 | 
			
		||||
		} else {
 | 
			
		||||
			ac.transport.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
	if ac.shutdownChan != nil {
 | 
			
		||||
		close(ac.shutdownChan)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -54,9 +54,9 @@ var (
 | 
			
		|||
	alpnProtoStr = []string{"h2"}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Credentials defines the common interface all supported credentials must
 | 
			
		||||
// implement.
 | 
			
		||||
type Credentials interface {
 | 
			
		||||
// PerRPCCredentials defines the common interface for the credentials which need to
 | 
			
		||||
// attach security information to every RPC (e.g., oauth2).
 | 
			
		||||
type PerRPCCredentials interface {
 | 
			
		||||
	// GetRequestMetadata gets the current request metadata, refreshing
 | 
			
		||||
	// tokens if required. This should be called by the transport layer on
 | 
			
		||||
	// each request, and the data should be populated in headers or other
 | 
			
		||||
| 
						 | 
				
			
			@ -87,9 +87,9 @@ type AuthInfo interface {
 | 
			
		|||
	AuthType() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TransportAuthenticator defines the common interface for all the live gRPC wire
 | 
			
		||||
// TransportCredentials defines the common interface for all the live gRPC wire
 | 
			
		||||
// protocols and supported transport security protocols (e.g., TLS, SSL).
 | 
			
		||||
type TransportAuthenticator interface {
 | 
			
		||||
type TransportCredentials interface {
 | 
			
		||||
	// ClientHandshake does the authentication handshake specified by the corresponding
 | 
			
		||||
	// authentication protocol on rawConn for clients. It returns the authenticated
 | 
			
		||||
	// connection and the corresponding auth information about the connection.
 | 
			
		||||
| 
						 | 
				
			
			@ -98,9 +98,8 @@ type TransportAuthenticator interface {
 | 
			
		|||
	// the authenticated connection and the corresponding auth information about
 | 
			
		||||
	// the connection.
 | 
			
		||||
	ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
 | 
			
		||||
	// Info provides the ProtocolInfo of this TransportAuthenticator.
 | 
			
		||||
	// Info provides the ProtocolInfo of this TransportCredentials.
 | 
			
		||||
	Info() ProtocolInfo
 | 
			
		||||
	Credentials
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TLSInfo contains the auth information for a TLS authenticated connection.
 | 
			
		||||
| 
						 | 
				
			
			@ -109,6 +108,7 @@ type TLSInfo struct {
 | 
			
		|||
	State tls.ConnectionState
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AuthType returns the type of TLSInfo as a string.
 | 
			
		||||
func (t TLSInfo) AuthType() string {
 | 
			
		||||
	return "tls"
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -185,20 +185,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
 | 
			
		|||
	return conn, TLSInfo{conn.ConnectionState()}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
 | 
			
		||||
func NewTLS(c *tls.Config) TransportAuthenticator {
 | 
			
		||||
// NewTLS uses c to construct a TransportCredentials based on TLS.
 | 
			
		||||
func NewTLS(c *tls.Config) TransportCredentials {
 | 
			
		||||
	tc := &tlsCreds{*c}
 | 
			
		||||
	tc.config.NextProtos = alpnProtoStr
 | 
			
		||||
	return tc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
 | 
			
		||||
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator {
 | 
			
		||||
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials {
 | 
			
		||||
	return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
 | 
			
		||||
func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, error) {
 | 
			
		||||
func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) {
 | 
			
		||||
	b, err := ioutil.ReadFile(certFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
| 
						 | 
				
			
			@ -211,13 +211,13 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator,
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
 | 
			
		||||
func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator {
 | 
			
		||||
func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials {
 | 
			
		||||
	return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewServerTLSFromFile constructs a TLS from the input certificate file and key
 | 
			
		||||
// file for server.
 | 
			
		||||
func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, error) {
 | 
			
		||||
func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) {
 | 
			
		||||
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,7 +66,8 @@ type Resolver interface {
 | 
			
		|||
// Watcher watches for the updates on the specified target.
 | 
			
		||||
type Watcher interface {
 | 
			
		||||
	// Next blocks until an update or error happens. It may return one or more
 | 
			
		||||
	// updates. The first call should get the full set of the results.
 | 
			
		||||
	// updates. The first call should get the full set of the results. It should
 | 
			
		||||
	// return an error if and only if Watcher cannot recover.
 | 
			
		||||
	Next() ([]*Update, error)
 | 
			
		||||
	// Close closes the Watcher.
 | 
			
		||||
	Close()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,243 +0,0 @@
 | 
			
		|||
/*
 | 
			
		||||
 *
 | 
			
		||||
 * Copyright 2014, Google Inc.
 | 
			
		||||
 * All rights reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * Redistribution and use in source and binary forms, with or without
 | 
			
		||||
 * modification, are permitted provided that the following conditions are
 | 
			
		||||
 * met:
 | 
			
		||||
 *
 | 
			
		||||
 *     * Redistributions of source code must retain the above copyright
 | 
			
		||||
 * notice, this list of conditions and the following disclaimer.
 | 
			
		||||
 *     * Redistributions in binary form must reproduce the above
 | 
			
		||||
 * copyright notice, this list of conditions and the following disclaimer
 | 
			
		||||
 * in the documentation and/or other materials provided with the
 | 
			
		||||
 * distribution.
 | 
			
		||||
 *     * Neither the name of Google Inc. nor the names of its
 | 
			
		||||
 * contributors may be used to endorse or promote products derived from
 | 
			
		||||
 * this software without specific prior written permission.
 | 
			
		||||
 *
 | 
			
		||||
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 | 
			
		||||
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 | 
			
		||||
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 | 
			
		||||
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 | 
			
		||||
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 | 
			
		||||
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 | 
			
		||||
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 | 
			
		||||
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 | 
			
		||||
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 | 
			
		||||
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | 
			
		||||
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package grpc
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"container/list"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
	"google.golang.org/grpc/grpclog"
 | 
			
		||||
	"google.golang.org/grpc/naming"
 | 
			
		||||
	"google.golang.org/grpc/transport"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Picker picks a Conn for RPC requests.
 | 
			
		||||
// This is EXPERIMENTAL and please do not implement your own Picker for now.
 | 
			
		||||
type Picker interface {
 | 
			
		||||
	// Init does initial processing for the Picker, e.g., initiate some connections.
 | 
			
		||||
	Init(cc *ClientConn) error
 | 
			
		||||
	// Pick blocks until either a transport.ClientTransport is ready for the upcoming RPC
 | 
			
		||||
	// or some error happens.
 | 
			
		||||
	Pick(ctx context.Context) (transport.ClientTransport, error)
 | 
			
		||||
	// PickAddr picks a peer address for connecting. This will be called repeated for
 | 
			
		||||
	// connecting/reconnecting.
 | 
			
		||||
	PickAddr() (string, error)
 | 
			
		||||
	// State returns the connectivity state of the underlying connections.
 | 
			
		||||
	State() (ConnectivityState, error)
 | 
			
		||||
	// WaitForStateChange blocks until the state changes to something other than
 | 
			
		||||
	// the sourceState. It returns the new state or error.
 | 
			
		||||
	WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error)
 | 
			
		||||
	// Close closes all the Conn's owned by this Picker.
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// unicastPicker is the default Picker which is used when there is no custom Picker
 | 
			
		||||
// specified by users. It always picks the same Conn.
 | 
			
		||||
type unicastPicker struct {
 | 
			
		||||
	target string
 | 
			
		||||
	conn   *Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastPicker) Init(cc *ClientConn) error {
 | 
			
		||||
	c, err := NewConn(cc)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	p.conn = c
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
 | 
			
		||||
	return p.conn.Wait(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastPicker) PickAddr() (string, error) {
 | 
			
		||||
	return p.target, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastPicker) State() (ConnectivityState, error) {
 | 
			
		||||
	return p.conn.State(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
 | 
			
		||||
	return p.conn.WaitForStateChange(ctx, sourceState)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastPicker) Close() error {
 | 
			
		||||
	if p.conn != nil {
 | 
			
		||||
		return p.conn.Close()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// unicastNamingPicker picks an address from a name resolver to set up the connection.
 | 
			
		||||
type unicastNamingPicker struct {
 | 
			
		||||
	cc       *ClientConn
 | 
			
		||||
	resolver naming.Resolver
 | 
			
		||||
	watcher  naming.Watcher
 | 
			
		||||
	mu       sync.Mutex
 | 
			
		||||
	// The list of the addresses are obtained from watcher.
 | 
			
		||||
	addrs *list.List
 | 
			
		||||
	// It tracks the current picked addr by PickAddr(). The next PickAddr may
 | 
			
		||||
	// push it forward on addrs.
 | 
			
		||||
	pickedAddr *list.Element
 | 
			
		||||
	conn       *Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewUnicastNamingPicker creates a Picker to pick addresses from a name resolver
 | 
			
		||||
// to connect.
 | 
			
		||||
func NewUnicastNamingPicker(r naming.Resolver) Picker {
 | 
			
		||||
	return &unicastNamingPicker{
 | 
			
		||||
		resolver: r,
 | 
			
		||||
		addrs:    list.New(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type addrInfo struct {
 | 
			
		||||
	addr string
 | 
			
		||||
	// Set to true if this addrInfo needs to be deleted in the next PickAddrr() call.
 | 
			
		||||
	deleting bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// processUpdates calls Watcher.Next() once and processes the obtained updates.
 | 
			
		||||
func (p *unicastNamingPicker) processUpdates() error {
 | 
			
		||||
	updates, err := p.watcher.Next()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, update := range updates {
 | 
			
		||||
		switch update.Op {
 | 
			
		||||
		case naming.Add:
 | 
			
		||||
			p.mu.Lock()
 | 
			
		||||
			p.addrs.PushBack(&addrInfo{
 | 
			
		||||
				addr: update.Addr,
 | 
			
		||||
			})
 | 
			
		||||
			p.mu.Unlock()
 | 
			
		||||
			// Initial connection setup
 | 
			
		||||
			if p.conn == nil {
 | 
			
		||||
				conn, err := NewConn(p.cc)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return err
 | 
			
		||||
				}
 | 
			
		||||
				p.conn = conn
 | 
			
		||||
			}
 | 
			
		||||
		case naming.Delete:
 | 
			
		||||
			p.mu.Lock()
 | 
			
		||||
			for e := p.addrs.Front(); e != nil; e = e.Next() {
 | 
			
		||||
				if update.Addr == e.Value.(*addrInfo).addr {
 | 
			
		||||
					if e == p.pickedAddr {
 | 
			
		||||
						// Do not remove the element now if it is the current picked
 | 
			
		||||
						// one. We leave the deletion to the next PickAddr() call.
 | 
			
		||||
						e.Value.(*addrInfo).deleting = true
 | 
			
		||||
						// Notify Conn to close it. All the live RPCs on this connection
 | 
			
		||||
						// will be aborted.
 | 
			
		||||
						p.conn.NotifyReset()
 | 
			
		||||
					} else {
 | 
			
		||||
						p.addrs.Remove(e)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			p.mu.Unlock()
 | 
			
		||||
		default:
 | 
			
		||||
			grpclog.Println("Unknown update.Op ", update.Op)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// monitor runs in a standalone goroutine to keep watching name resolution updates until the watcher
 | 
			
		||||
// is closed.
 | 
			
		||||
func (p *unicastNamingPicker) monitor() {
 | 
			
		||||
	for {
 | 
			
		||||
		if err := p.processUpdates(); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastNamingPicker) Init(cc *ClientConn) error {
 | 
			
		||||
	w, err := p.resolver.Resolve(cc.target)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	p.watcher = w
 | 
			
		||||
	p.cc = cc
 | 
			
		||||
	// Get the initial name resolution.
 | 
			
		||||
	if err := p.processUpdates(); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	go p.monitor()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastNamingPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
 | 
			
		||||
	return p.conn.Wait(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastNamingPicker) PickAddr() (string, error) {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
	if p.pickedAddr == nil {
 | 
			
		||||
		p.pickedAddr = p.addrs.Front()
 | 
			
		||||
	} else {
 | 
			
		||||
		pa := p.pickedAddr
 | 
			
		||||
		p.pickedAddr = pa.Next()
 | 
			
		||||
		if pa.Value.(*addrInfo).deleting {
 | 
			
		||||
			p.addrs.Remove(pa)
 | 
			
		||||
		}
 | 
			
		||||
		if p.pickedAddr == nil {
 | 
			
		||||
			p.pickedAddr = p.addrs.Front()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if p.pickedAddr == nil {
 | 
			
		||||
		return "", fmt.Errorf("there is no address available to pick")
 | 
			
		||||
	}
 | 
			
		||||
	return p.pickedAddr.Value.(*addrInfo).addr, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastNamingPicker) State() (ConnectivityState, error) {
 | 
			
		||||
	return 0, fmt.Errorf("State() is not supported for unicastNamingPicker")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastNamingPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
 | 
			
		||||
	return 0, fmt.Errorf("WaitForStateChange is not supported for unicastNamingPciker")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *unicastNamingPicker) Close() error {
 | 
			
		||||
	p.watcher.Close()
 | 
			
		||||
	p.conn.Close()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -61,7 +61,7 @@ type Codec interface {
 | 
			
		|||
	String() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// protoCodec is a Codec implemetation with protobuf. It is the default codec for gRPC.
 | 
			
		||||
// protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC.
 | 
			
		||||
type protoCodec struct{}
 | 
			
		||||
 | 
			
		||||
func (protoCodec) Marshal(v interface{}) ([]byte, error) {
 | 
			
		||||
| 
						 | 
				
			
			@ -187,7 +187,7 @@ const (
 | 
			
		|||
	compressionMade
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// parser reads complelete gRPC messages from the underlying reader.
 | 
			
		||||
// parser reads complete gRPC messages from the underlying reader.
 | 
			
		||||
type parser struct {
 | 
			
		||||
	// r is the underlying reader.
 | 
			
		||||
	// See the comment on recvMsg for the permissible
 | 
			
		||||
| 
						 | 
				
			
			@ -284,14 +284,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
 | 
			
		|||
	switch pf {
 | 
			
		||||
	case compressionNone:
 | 
			
		||||
	case compressionMade:
 | 
			
		||||
		if recvCompress == "" {
 | 
			
		||||
			return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress)
 | 
			
		||||
		}
 | 
			
		||||
		if dc == nil || recvCompress != dc.Type() {
 | 
			
		||||
			return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
 | 
			
		||||
			return transport.StreamErrorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
 | 
			
		||||
		return transport.StreamErrorf(codes.Internal, "grpc: received unexpected payload format %d", pf)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -73,6 +73,7 @@ type ServiceDesc struct {
 | 
			
		|||
	HandlerType interface{}
 | 
			
		||||
	Methods     []MethodDesc
 | 
			
		||||
	Streams     []StreamDesc
 | 
			
		||||
	Metadata    interface{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// service consists of the information of the server serving this service and
 | 
			
		||||
| 
						 | 
				
			
			@ -95,10 +96,12 @@ type Server struct {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
type options struct {
 | 
			
		||||
	creds                credentials.Credentials
 | 
			
		||||
	creds                credentials.TransportCredentials
 | 
			
		||||
	codec                Codec
 | 
			
		||||
	cp                   Compressor
 | 
			
		||||
	dc                   Decompressor
 | 
			
		||||
	unaryInt             UnaryServerInterceptor
 | 
			
		||||
	streamInt            StreamServerInterceptor
 | 
			
		||||
	maxConcurrentStreams uint32
 | 
			
		||||
	useHandlerImpl       bool // use http.Handler-based server
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -113,12 +116,14 @@ func CustomCodec(codec Codec) ServerOption {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
 | 
			
		||||
func RPCCompressor(cp Compressor) ServerOption {
 | 
			
		||||
	return func(o *options) {
 | 
			
		||||
		o.cp = cp
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
 | 
			
		||||
func RPCDecompressor(dc Decompressor) ServerOption {
 | 
			
		||||
	return func(o *options) {
 | 
			
		||||
		o.dc = dc
 | 
			
		||||
| 
						 | 
				
			
			@ -134,12 +139,35 @@ func MaxConcurrentStreams(n uint32) ServerOption {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// Creds returns a ServerOption that sets credentials for server connections.
 | 
			
		||||
func Creds(c credentials.Credentials) ServerOption {
 | 
			
		||||
func Creds(c credentials.TransportCredentials) ServerOption {
 | 
			
		||||
	return func(o *options) {
 | 
			
		||||
		o.creds = c
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
 | 
			
		||||
// server. Only one unary interceptor can be installed. The construction of multiple
 | 
			
		||||
// interceptors (e.g., chaining) can be implemented at the caller.
 | 
			
		||||
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
 | 
			
		||||
	return func(o *options) {
 | 
			
		||||
		if o.unaryInt != nil {
 | 
			
		||||
			panic("The unary server interceptor has been set.")
 | 
			
		||||
		}
 | 
			
		||||
		o.unaryInt = i
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
 | 
			
		||||
// server. Only one stream interceptor can be installed.
 | 
			
		||||
func StreamInterceptor(i StreamServerInterceptor) ServerOption {
 | 
			
		||||
	return func(o *options) {
 | 
			
		||||
		if o.streamInt != nil {
 | 
			
		||||
			panic("The stream server interceptor has been set.")
 | 
			
		||||
		}
 | 
			
		||||
		o.streamInt = i
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewServer creates a gRPC server which has no service registered and has not
 | 
			
		||||
// started to accept requests yet.
 | 
			
		||||
func NewServer(opt ...ServerOption) *Server {
 | 
			
		||||
| 
						 | 
				
			
			@ -222,22 +250,23 @@ var (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 | 
			
		||||
	creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
 | 
			
		||||
	if !ok {
 | 
			
		||||
	if s.opts.creds == nil {
 | 
			
		||||
		return rawConn, nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return creds.ServerHandshake(rawConn)
 | 
			
		||||
	return s.opts.creds.ServerHandshake(rawConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Serve accepts incoming connections on the listener lis, creating a new
 | 
			
		||||
// ServerTransport and service goroutine for each. The service goroutines
 | 
			
		||||
// read gRPC requests and then call the registered handlers to reply to them.
 | 
			
		||||
// Service returns when lis.Accept fails.
 | 
			
		||||
// Service returns when lis.Accept fails. lis will be closed when
 | 
			
		||||
// this method returns.
 | 
			
		||||
func (s *Server) Serve(lis net.Listener) error {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	s.printf("serving")
 | 
			
		||||
	if s.lis == nil {
 | 
			
		||||
		s.mu.Unlock()
 | 
			
		||||
		lis.Close()
 | 
			
		||||
		return ErrServerStopped
 | 
			
		||||
	}
 | 
			
		||||
	s.lis[lis] = true
 | 
			
		||||
| 
						 | 
				
			
			@ -435,6 +464,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 | 
			
		|||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	if s.opts.cp != nil {
 | 
			
		||||
		// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
 | 
			
		||||
		stream.SetSendCompress(s.opts.cp.Type())
 | 
			
		||||
	}
 | 
			
		||||
	p := &parser{r: stream}
 | 
			
		||||
	for {
 | 
			
		||||
		pf, req, err := p.recvMsg()
 | 
			
		||||
| 
						 | 
				
			
			@ -494,7 +527,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 | 
			
		|||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		reply, appErr := md.Handler(srv.server, stream.Context(), df, nil)
 | 
			
		||||
		reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
 | 
			
		||||
		if appErr != nil {
 | 
			
		||||
			if err, ok := appErr.(rpcError); ok {
 | 
			
		||||
				statusCode = err.code
 | 
			
		||||
| 
						 | 
				
			
			@ -520,9 +553,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
 | 
			
		|||
			Last:  true,
 | 
			
		||||
			Delay: false,
 | 
			
		||||
		}
 | 
			
		||||
		if s.opts.cp != nil {
 | 
			
		||||
			stream.SetSendCompress(s.opts.cp.Type())
 | 
			
		||||
		}
 | 
			
		||||
		if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
 | 
			
		||||
			switch err := err.(type) {
 | 
			
		||||
			case transport.ConnectionError:
 | 
			
		||||
| 
						 | 
				
			
			@ -572,7 +602,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
 | 
			
		|||
			ss.mu.Unlock()
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	if appErr := sd.Handler(srv.server, ss); appErr != nil {
 | 
			
		||||
	var appErr error
 | 
			
		||||
	if s.opts.streamInt == nil {
 | 
			
		||||
		appErr = sd.Handler(srv.server, ss)
 | 
			
		||||
	} else {
 | 
			
		||||
		info := &StreamServerInfo{
 | 
			
		||||
			FullMethod:     stream.Method(),
 | 
			
		||||
			IsClientStream: sd.ClientStreams,
 | 
			
		||||
			IsServerStream: sd.ServerStreams,
 | 
			
		||||
		}
 | 
			
		||||
		appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
 | 
			
		||||
	}
 | 
			
		||||
	if appErr != nil {
 | 
			
		||||
		if err, ok := appErr.(rpcError); ok {
 | 
			
		||||
			ss.statusCode = err.code
 | 
			
		||||
			ss.statusDesc = err.desc
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -79,9 +79,9 @@ type Stream interface {
 | 
			
		|||
	RecvMsg(m interface{}) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ClientStream defines the interface a client stream has to satify.
 | 
			
		||||
// ClientStream defines the interface a client stream has to satisfy.
 | 
			
		||||
type ClientStream interface {
 | 
			
		||||
	// Header returns the header metedata received from the server if there
 | 
			
		||||
	// Header returns the header metadata received from the server if there
 | 
			
		||||
	// is any. It blocks if the metadata is not ready to read.
 | 
			
		||||
	Header() (metadata.MD, error)
 | 
			
		||||
	// Trailer returns the trailer metadata from the server. It must be called
 | 
			
		||||
| 
						 | 
				
			
			@ -103,12 +103,16 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 | 
			
		|||
	var (
 | 
			
		||||
		t   transport.ClientTransport
 | 
			
		||||
		err error
 | 
			
		||||
		put func()
 | 
			
		||||
	)
 | 
			
		||||
	t, err = cc.dopts.picker.Pick(ctx)
 | 
			
		||||
	// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
 | 
			
		||||
	gopts := BalancerGetOptions{
 | 
			
		||||
		BlockingWait: false,
 | 
			
		||||
	}
 | 
			
		||||
	t, put, err = cc.getTransport(ctx, gopts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, toRPCErr(err)
 | 
			
		||||
	}
 | 
			
		||||
	// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
 | 
			
		||||
	callHdr := &transport.CallHdr{
 | 
			
		||||
		Host:   cc.authority,
 | 
			
		||||
		Method: method,
 | 
			
		||||
| 
						 | 
				
			
			@ -119,6 +123,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 | 
			
		|||
	}
 | 
			
		||||
	cs := &clientStream{
 | 
			
		||||
		desc:    desc,
 | 
			
		||||
		put:     put,
 | 
			
		||||
		codec:   cc.dopts.codec,
 | 
			
		||||
		cp:      cc.dopts.cp,
 | 
			
		||||
		dc:      cc.dopts.dc,
 | 
			
		||||
| 
						 | 
				
			
			@ -174,6 +179,7 @@ type clientStream struct {
 | 
			
		|||
	tracing bool // set to EnableTracing when the clientStream is created.
 | 
			
		||||
 | 
			
		||||
	mu     sync.Mutex
 | 
			
		||||
	put    func()
 | 
			
		||||
	closed bool
 | 
			
		||||
	// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
 | 
			
		||||
	// and is set to nil when the clientStream's finish method is called.
 | 
			
		||||
| 
						 | 
				
			
			@ -311,6 +317,10 @@ func (cs *clientStream) finish(err error) {
 | 
			
		|||
	}
 | 
			
		||||
	cs.mu.Lock()
 | 
			
		||||
	defer cs.mu.Unlock()
 | 
			
		||||
	if cs.put != nil {
 | 
			
		||||
		cs.put()
 | 
			
		||||
		cs.put = nil
 | 
			
		||||
	}
 | 
			
		||||
	if cs.trInfo.tr != nil {
 | 
			
		||||
		if err == nil || err == io.EOF {
 | 
			
		||||
			cs.trInfo.tr.LazyPrintf("RPC: [OK]")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -101,9 +101,8 @@ type payload struct {
 | 
			
		|||
func (p payload) String() string {
 | 
			
		||||
	if p.sent {
 | 
			
		||||
		return fmt.Sprintf("sent: %v", p.msg)
 | 
			
		||||
	} else {
 | 
			
		||||
		return fmt.Sprintf("recv: %v", p.msg)
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("recv: %v", p.msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type fmtStringer struct {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -65,7 +65,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
 | 
			
		|||
	if r.Method != "POST" {
 | 
			
		||||
		return nil, errors.New("invalid gRPC request method")
 | 
			
		||||
	}
 | 
			
		||||
	if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
 | 
			
		||||
	if !validContentType(r.Header.Get("Content-Type")) {
 | 
			
		||||
		return nil, errors.New("invalid gRPC request content-type")
 | 
			
		||||
	}
 | 
			
		||||
	if _, ok := w.(http.Flusher); !ok {
 | 
			
		||||
| 
						 | 
				
			
			@ -92,9 +92,12 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	var metakv []string
 | 
			
		||||
	if r.Host != "" {
 | 
			
		||||
		metakv = append(metakv, ":authority", r.Host)
 | 
			
		||||
	}
 | 
			
		||||
	for k, vv := range r.Header {
 | 
			
		||||
		k = strings.ToLower(k)
 | 
			
		||||
		if isReservedHeader(k) {
 | 
			
		||||
		if isReservedHeader(k) && !isWhitelistedPseudoHeader(k) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		for _, v := range vv {
 | 
			
		||||
| 
						 | 
				
			
			@ -108,7 +111,6 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
 | 
			
		|||
				}
 | 
			
		||||
			}
 | 
			
		||||
			metakv = append(metakv, k, v)
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	st.headerMD = metadata.Pairs(metakv...)
 | 
			
		||||
| 
						 | 
				
			
			@ -196,6 +198,10 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
 | 
			
		|||
		}
 | 
			
		||||
		if md := s.Trailer(); len(md) > 0 {
 | 
			
		||||
			for k, vv := range md {
 | 
			
		||||
				// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
 | 
			
		||||
				if isReservedHeader(k) {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				for _, v := range vv {
 | 
			
		||||
					// http2 ResponseWriter mechanism to
 | 
			
		||||
					// send undeclared Trailers after the
 | 
			
		||||
| 
						 | 
				
			
			@ -249,6 +255,10 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
 | 
			
		|||
		ht.writeCommonHeaders(s)
 | 
			
		||||
		h := ht.rw.Header()
 | 
			
		||||
		for k, vv := range md {
 | 
			
		||||
			// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
 | 
			
		||||
			if isReservedHeader(k) {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			for _, v := range vv {
 | 
			
		||||
				h.Add(k, v)
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -35,7 +35,6 @@ package transport
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net"
 | 
			
		||||
| 
						 | 
				
			
			@ -89,7 +88,7 @@ type http2Client struct {
 | 
			
		|||
	// The scheme used: https if TLS is on, http otherwise.
 | 
			
		||||
	scheme string
 | 
			
		||||
 | 
			
		||||
	authCreds []credentials.Credentials
 | 
			
		||||
	creds []credentials.PerRPCCredentials
 | 
			
		||||
 | 
			
		||||
	mu            sync.Mutex     // guard the following variables
 | 
			
		||||
	state         transportState // the state of underlying connection
 | 
			
		||||
| 
						 | 
				
			
			@ -118,19 +117,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 | 
			
		|||
		return nil, ConnectionErrorf("transport: %v", connErr)
 | 
			
		||||
	}
 | 
			
		||||
	var authInfo credentials.AuthInfo
 | 
			
		||||
	for _, c := range opts.AuthOptions {
 | 
			
		||||
		if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
 | 
			
		||||
			scheme = "https"
 | 
			
		||||
			// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
 | 
			
		||||
			// multiple ones provided. Revisit this if it is not appropriate. Probably
 | 
			
		||||
			// place the ClientTransport construction into a separate function to make
 | 
			
		||||
			// things clear.
 | 
			
		||||
			if timeout > 0 {
 | 
			
		||||
				timeout -= time.Since(startT)
 | 
			
		||||
			}
 | 
			
		||||
			conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
 | 
			
		||||
			break
 | 
			
		||||
	if opts.TransportCredentials != nil {
 | 
			
		||||
		scheme = "https"
 | 
			
		||||
		if timeout > 0 {
 | 
			
		||||
			timeout -= time.Since(startT)
 | 
			
		||||
		}
 | 
			
		||||
		conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
 | 
			
		||||
	}
 | 
			
		||||
	if connErr != nil {
 | 
			
		||||
		return nil, ConnectionErrorf("transport: %v", connErr)
 | 
			
		||||
| 
						 | 
				
			
			@ -140,29 +132,6 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 | 
			
		|||
			conn.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	// Send connection preface to server.
 | 
			
		||||
	n, err := conn.Write(clientPreface)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, ConnectionErrorf("transport: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if n != len(clientPreface) {
 | 
			
		||||
		return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
 | 
			
		||||
	}
 | 
			
		||||
	framer := newFramer(conn)
 | 
			
		||||
	if initialWindowSize != defaultWindowSize {
 | 
			
		||||
		err = framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
 | 
			
		||||
	} else {
 | 
			
		||||
		err = framer.writeSettings(true)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, ConnectionErrorf("transport: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	// Adjust the connection flow control window if needed.
 | 
			
		||||
	if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
 | 
			
		||||
		if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
 | 
			
		||||
			return nil, ConnectionErrorf("transport: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	ua := primaryUA
 | 
			
		||||
	if opts.UserAgent != "" {
 | 
			
		||||
		ua = opts.UserAgent + " " + ua
 | 
			
		||||
| 
						 | 
				
			
			@ -178,7 +147,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 | 
			
		|||
		writableChan:    make(chan int, 1),
 | 
			
		||||
		shutdownChan:    make(chan struct{}),
 | 
			
		||||
		errorChan:       make(chan struct{}),
 | 
			
		||||
		framer:          framer,
 | 
			
		||||
		framer:          newFramer(conn),
 | 
			
		||||
		hBuf:            &buf,
 | 
			
		||||
		hEnc:            hpack.NewEncoder(&buf),
 | 
			
		||||
		controlBuf:      newRecvBuffer(),
 | 
			
		||||
| 
						 | 
				
			
			@ -187,17 +156,42 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 | 
			
		|||
		scheme:          scheme,
 | 
			
		||||
		state:           reachable,
 | 
			
		||||
		activeStreams:   make(map[uint32]*Stream),
 | 
			
		||||
		authCreds:       opts.AuthOptions,
 | 
			
		||||
		creds:           opts.PerRPCCredentials,
 | 
			
		||||
		maxStreams:      math.MaxInt32,
 | 
			
		||||
		streamSendQuota: defaultWindowSize,
 | 
			
		||||
	}
 | 
			
		||||
	// Start the reader goroutine for incoming message. Each transport has
 | 
			
		||||
	// a dedicated goroutine which reads HTTP2 frame from network. Then it
 | 
			
		||||
	// dispatches the frame to the corresponding stream entity.
 | 
			
		||||
	go t.reader()
 | 
			
		||||
	// Send connection preface to server.
 | 
			
		||||
	n, err := t.conn.Write(clientPreface)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Close()
 | 
			
		||||
		return nil, ConnectionErrorf("transport: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if n != len(clientPreface) {
 | 
			
		||||
		t.Close()
 | 
			
		||||
		return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
 | 
			
		||||
	}
 | 
			
		||||
	if initialWindowSize != defaultWindowSize {
 | 
			
		||||
		err = t.framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
 | 
			
		||||
	} else {
 | 
			
		||||
		err = t.framer.writeSettings(true)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Close()
 | 
			
		||||
		return nil, ConnectionErrorf("transport: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	// Adjust the connection flow control window if needed.
 | 
			
		||||
	if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
 | 
			
		||||
		if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
 | 
			
		||||
			t.Close()
 | 
			
		||||
			return nil, ConnectionErrorf("transport: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	go t.controller()
 | 
			
		||||
	t.writableChan <- 0
 | 
			
		||||
	// Start the reader goroutine for incoming message. The threading model
 | 
			
		||||
	// on receiving is that each transport has a dedicated goroutine which
 | 
			
		||||
	// reads HTTP2 frame from network. Then it dispatches the frame to the
 | 
			
		||||
	// corresponding stream entity.
 | 
			
		||||
	go t.reader()
 | 
			
		||||
	return t, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -247,7 +241,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 | 
			
		|||
	}
 | 
			
		||||
	ctx = peer.NewContext(ctx, pr)
 | 
			
		||||
	authData := make(map[string]string)
 | 
			
		||||
	for _, c := range t.authCreds {
 | 
			
		||||
	for _, c := range t.creds {
 | 
			
		||||
		// Construct URI required to get auth request metadata.
 | 
			
		||||
		var port string
 | 
			
		||||
		if pos := strings.LastIndex(t.target, ":"); pos != -1 {
 | 
			
		||||
| 
						 | 
				
			
			@ -270,6 +264,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 | 
			
		|||
		}
 | 
			
		||||
	}
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	if t.activeStreams == nil {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		return nil, ErrConnClosing
 | 
			
		||||
	}
 | 
			
		||||
	if t.state != reachable {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		return nil, ErrConnClosing
 | 
			
		||||
| 
						 | 
				
			
			@ -287,7 +285,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 | 
			
		|||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
 | 
			
		||||
		// t.streamsQuota will be updated when t.CloseStream is invoked.
 | 
			
		||||
		// Return the quota back now because there is no stream returned to the caller.
 | 
			
		||||
		if _, ok := err.(StreamError); ok && checkStreamsQuota {
 | 
			
		||||
			t.streamsQuota.add(1)
 | 
			
		||||
		}
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
| 
						 | 
				
			
			@ -339,6 +340,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 | 
			
		|||
	if md, ok := metadata.FromContext(ctx); ok {
 | 
			
		||||
		hasMD = true
 | 
			
		||||
		for k, v := range md {
 | 
			
		||||
			// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
 | 
			
		||||
			if isReservedHeader(k) {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			for _, entry := range v {
 | 
			
		||||
				t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			@ -388,9 +393,19 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 | 
			
		|||
func (t *http2Client) CloseStream(s *Stream, err error) {
 | 
			
		||||
	var updateStreams bool
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	if t.activeStreams == nil {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if t.streamsQuota != nil {
 | 
			
		||||
		updateStreams = true
 | 
			
		||||
	}
 | 
			
		||||
	if t.state == draining && len(t.activeStreams) == 1 {
 | 
			
		||||
		// The transport is draining and s is the last live stream on t.
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		t.Close()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	delete(t.activeStreams, s.id)
 | 
			
		||||
	t.mu.Unlock()
 | 
			
		||||
	if updateStreams {
 | 
			
		||||
| 
						 | 
				
			
			@ -427,9 +442,12 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 | 
			
		|||
// accessed any more.
 | 
			
		||||
func (t *http2Client) Close() (err error) {
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	if t.state == reachable {
 | 
			
		||||
		close(t.errorChan)
 | 
			
		||||
	}
 | 
			
		||||
	if t.state == closing {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		return errors.New("transport: Close() was already called")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	t.state = closing
 | 
			
		||||
	t.mu.Unlock()
 | 
			
		||||
| 
						 | 
				
			
			@ -452,6 +470,25 @@ func (t *http2Client) Close() (err error) {
 | 
			
		|||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *http2Client) GracefulClose() error {
 | 
			
		||||
	t.mu.Lock()
 | 
			
		||||
	if t.state == closing {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	if t.state == draining {
 | 
			
		||||
		t.mu.Unlock()
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	t.state = draining
 | 
			
		||||
	active := len(t.activeStreams)
 | 
			
		||||
	t.mu.Unlock()
 | 
			
		||||
	if active == 0 {
 | 
			
		||||
		return t.Close()
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
 | 
			
		||||
// should proceed only if Write returns nil.
 | 
			
		||||
// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later
 | 
			
		||||
| 
						 | 
				
			
			@ -574,6 +611,11 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
 | 
			
		|||
// Window updates will deliver to the controller for sending when
 | 
			
		||||
// the cumulative quota exceeds the corresponding threshold.
 | 
			
		||||
func (t *http2Client) updateWindow(s *Stream, n uint32) {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	if s.state == streamDone {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if w := t.fc.onRead(n); w > 0 {
 | 
			
		||||
		t.controlBuf.put(&windowUpdate{0, w})
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -303,6 +303,11 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
 | 
			
		|||
// Window updates will deliver to the controller for sending when
 | 
			
		||||
// the cumulative quota exceeds the corresponding threshold.
 | 
			
		||||
func (t *http2Server) updateWindow(s *Stream, n uint32) {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	if s.state == streamDone {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if w := t.fc.onRead(n); w > 0 {
 | 
			
		||||
		t.controlBuf.put(&windowUpdate{0, w})
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -455,6 +460,10 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
 | 
			
		|||
		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
 | 
			
		||||
	}
 | 
			
		||||
	for k, v := range md {
 | 
			
		||||
		if isReservedHeader(k) {
 | 
			
		||||
			// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		for _, entry := range v {
 | 
			
		||||
			t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			@ -497,6 +506,10 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
 | 
			
		|||
	t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc})
 | 
			
		||||
	// Attach the trailer metadata.
 | 
			
		||||
	for k, v := range s.trailer {
 | 
			
		||||
		// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
 | 
			
		||||
		if isReservedHeader(k) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		for _, entry := range v {
 | 
			
		||||
			t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -127,16 +127,40 @@ func isReservedHeader(hdr string) bool {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// isWhitelistedPseudoHeader checks whether hdr belongs to HTTP2 pseudoheaders
 | 
			
		||||
// that should be propagated into metadata visible to users.
 | 
			
		||||
func isWhitelistedPseudoHeader(hdr string) bool {
 | 
			
		||||
	switch hdr {
 | 
			
		||||
	case ":authority":
 | 
			
		||||
		return true
 | 
			
		||||
	default:
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *decodeState) setErr(err error) {
 | 
			
		||||
	if d.err == nil {
 | 
			
		||||
		d.err = err
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validContentType(t string) bool {
 | 
			
		||||
	e := "application/grpc"
 | 
			
		||||
	if !strings.HasPrefix(t, e) {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	// Support variations on the content-type
 | 
			
		||||
	// (e.g. "application/grpc+blah", "application/grpc;blah").
 | 
			
		||||
	if len(t) > len(e) && t[len(e)] != '+' && t[len(e)] != ';' {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *decodeState) processHeaderField(f hpack.HeaderField) {
 | 
			
		||||
	switch f.Name {
 | 
			
		||||
	case "content-type":
 | 
			
		||||
		if !strings.Contains(f.Value, "application/grpc") {
 | 
			
		||||
		if !validContentType(f.Value) {
 | 
			
		||||
			d.setErr(StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
| 
						 | 
				
			
			@ -162,7 +186,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
 | 
			
		|||
	case ":path":
 | 
			
		||||
		d.method = f.Value
 | 
			
		||||
	default:
 | 
			
		||||
		if !isReservedHeader(f.Name) {
 | 
			
		||||
		if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) {
 | 
			
		||||
			if f.Name == "user-agent" {
 | 
			
		||||
				i := strings.LastIndex(f.Value, " ")
 | 
			
		||||
				if i == -1 {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -321,6 +321,7 @@ const (
 | 
			
		|||
	reachable transportState = iota
 | 
			
		||||
	unreachable
 | 
			
		||||
	closing
 | 
			
		||||
	draining
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NewServerTransport creates a ServerTransport with conn or non-nil error
 | 
			
		||||
| 
						 | 
				
			
			@ -335,9 +336,11 @@ type ConnectOptions struct {
 | 
			
		|||
	UserAgent string
 | 
			
		||||
	// Dialer specifies how to dial a network address.
 | 
			
		||||
	Dialer func(string, time.Duration) (net.Conn, error)
 | 
			
		||||
	// AuthOptions stores the credentials required to setup a client connection and/or issue RPCs.
 | 
			
		||||
	AuthOptions []credentials.Credentials
 | 
			
		||||
	// Timeout specifies the timeout for dialing a client connection.
 | 
			
		||||
	// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
 | 
			
		||||
	PerRPCCredentials []credentials.PerRPCCredentials
 | 
			
		||||
	// TransportCredentials stores the Authenticator required to setup a client connection.
 | 
			
		||||
	TransportCredentials credentials.TransportCredentials
 | 
			
		||||
	// Timeout specifies the timeout for dialing a ClientTransport.
 | 
			
		||||
	Timeout time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -391,6 +394,10 @@ type ClientTransport interface {
 | 
			
		|||
	// is called only once.
 | 
			
		||||
	Close() error
 | 
			
		||||
 | 
			
		||||
	// GracefulClose starts to tear down the transport. It stops accepting
 | 
			
		||||
	// new RPCs and wait the completion of the pending RPCs.
 | 
			
		||||
	GracefulClose() error
 | 
			
		||||
 | 
			
		||||
	// Write sends the data for the given stream. A nil stream indicates
 | 
			
		||||
	// the write is to be performed on the transport as a whole.
 | 
			
		||||
	Write(s *Stream, data []byte, opts *Options) error
 | 
			
		||||
| 
						 | 
				
			
			@ -468,7 +475,7 @@ func (e ConnectionError) Error() string {
 | 
			
		|||
	return fmt.Sprintf("connection error: desc = %q", e.Desc)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Define some common ConnectionErrors.
 | 
			
		||||
// ErrConnClosing indicates that the transport is closing.
 | 
			
		||||
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
 | 
			
		||||
 | 
			
		||||
// StreamError is an error that only affects one stream within a connection.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue