Allow gRPC clients to connect to multiple backends (#1918)

Fixes #1917 and #1755, also updates google.golang.org/grpc to b60d3e9e.
This commit is contained in:
Roland Bracewell Shoemaker 2016-06-15 16:50:56 -07:00 committed by Jacob Hoffman-Andrews
parent f04b922aff
commit 92e0704b1b
27 changed files with 1303 additions and 671 deletions

18
Godeps/Godeps.json generated
View File

@ -200,39 +200,39 @@
}, },
{ {
"ImportPath": "google.golang.org/grpc", "ImportPath": "google.golang.org/grpc",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "google.golang.org/grpc/codes", "ImportPath": "google.golang.org/grpc/codes",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "google.golang.org/grpc/credentials", "ImportPath": "google.golang.org/grpc/credentials",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "google.golang.org/grpc/grpclog", "ImportPath": "google.golang.org/grpc/grpclog",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "google.golang.org/grpc/internal", "ImportPath": "google.golang.org/grpc/internal",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "google.golang.org/grpc/metadata", "ImportPath": "google.golang.org/grpc/metadata",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "google.golang.org/grpc/naming", "ImportPath": "google.golang.org/grpc/naming",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "google.golang.org/grpc/peer", "ImportPath": "google.golang.org/grpc/peer",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "google.golang.org/grpc/transport", "ImportPath": "google.golang.org/grpc/transport",
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a" "Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
}, },
{ {
"ImportPath": "gopkg.in/gorp.v1", "ImportPath": "gopkg.in/gorp.v1",

View File

@ -13,14 +13,14 @@ import (
) )
func main() { func main() {
addr := flag.String("addr", "127.0.0.1:9090", "CCS address") addr := flag.String("addr", "boulder:9090", "CCS address")
name := flag.String("name", "", "Name to check") name := flag.String("name", "", "Name to check")
issuer := flag.String("issuerDomain", "", "Issuer domain to check against") issuer := flag.String("issuerDomain", "", "Issuer domain to check against")
flag.Parse() flag.Parse()
// Set up a connection to the server. // Set up a connection to the server.
conn, err := bgrpc.ClientSetup(&cmd.GRPCClientConfig{ conn, err := bgrpc.ClientSetup(&cmd.GRPCClientConfig{
ServerAddress: *addr, ServerAddresses: []string{*addr},
ServerIssuerPath: "test/grpc-creds/ca.pem", ServerIssuerPath: "test/grpc-creds/ca.pem",
ClientCertificatePath: "test/grpc-creds/client.pem", ClientCertificatePath: "test/grpc-creds/client.pem",
ClientKeyPath: "test/grpc-creds/key.pem", ClientKeyPath: "test/grpc-creds/key.pem",

View File

@ -511,7 +511,7 @@ type LogDescription struct {
// GRPCClientConfig contains the information needed to talk to the gRPC service // GRPCClientConfig contains the information needed to talk to the gRPC service
type GRPCClientConfig struct { type GRPCClientConfig struct {
ServerAddress string ServerAddresses []string
ServerIssuerPath string ServerIssuerPath string
ClientCertificatePath string ClientCertificatePath string
ClientKeyPath string ClientKeyPath string

47
grpc/balancer.go Normal file
View File

@ -0,0 +1,47 @@
package grpc
import (
"google.golang.org/grpc/naming"
)
// staticResolver implements both the naming.Resolver and naming.Watcher
// interfaces. It always returns a single static list then blocks forever
type staticResolver struct {
addresses []*naming.Update
}
func newStaticResolver(addresses []string) *staticResolver {
sr := &staticResolver{}
for _, a := range addresses {
sr.addresses = append(sr.addresses, &naming.Update{
Op: naming.Add,
Addr: a,
})
}
return sr
}
// Resolve just returns the staticResolver it was called from as it satisfies
// both the naming.Resolver and naming.Watcher interfaces
func (sr *staticResolver) Resolve(target string) (naming.Watcher, error) {
return sr, nil
}
// Next is called in a loop by grpc.RoundRobin expecting updates to which addresses are
// appropriate. Since we just want to return a static list once return a list on the first
// call then block forever on the second instead of sitting in a tight loop
func (sr *staticResolver) Next() ([]*naming.Update, error) {
if sr.addresses != nil {
addrs := sr.addresses
sr.addresses = nil
return addrs, nil
}
// Since staticResolver.Next is called in a tight loop block forever
// after returning the initial set of addresses
forever := make(chan struct{})
<-forever
return nil, nil
}
// Close does nothing
func (sr *staticResolver) Close() {}

39
grpc/balancer_test.go Normal file
View File

@ -0,0 +1,39 @@
package grpc
import (
"testing"
"time"
"google.golang.org/grpc/naming"
"github.com/letsencrypt/boulder/test"
)
func TestStaticResolver(t *testing.T) {
names := []string{"test:443"}
sr := newStaticResolver(names)
watcher, err := sr.Resolve("")
test.AssertNotError(t, err, "staticResolver.Resolve failed")
// Make sure doing this doesn't break anything (since it does nothing)
watcher.Close()
updates, err := watcher.Next()
test.AssertNotError(t, err, "staticwatcher.Next failed")
test.AssertEquals(t, len(names), len(updates))
test.AssertEquals(t, updates[0].Addr, "test:443")
test.AssertEquals(t, updates[0].Op, naming.Add)
test.AssertEquals(t, updates[0].Metadata, nil)
returned := make(chan struct{}, 1)
go func() {
_, err = watcher.Next()
test.AssertNotError(t, err, "watcher.Next failed")
returned <- struct{}{}
}()
select {
case <-returned:
t.Fatal("staticWatcher.Next returned something after the first call")
case <-time.After(time.Millisecond * 500):
}
}

77
grpc/creds/creds.go Normal file
View File

@ -0,0 +1,77 @@
package creds
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials"
)
// transportCredentials is a grpc/credentials.TransportCredentials which supports
// connecting to, and verifying multiple DNS names
type transportCredentials struct {
roots *x509.CertPool
clients []tls.Certificate
}
// New returns a new initialized grpc/credentials.TransportCredentials
func New(rootCAs *x509.CertPool, clientCerts []tls.Certificate) credentials.TransportCredentials {
return &transportCredentials{rootCAs, clientCerts}
}
// ClientHandshake performs the TLS handshake for a client -> server connection
func (tc *transportCredentials) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, credentials.AuthInfo, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, nil, err
}
conn := tls.Client(rawConn, &tls.Config{
ServerName: host,
RootCAs: tc.roots,
Certificates: tc.clients,
MinVersion: tls.VersionTLS12, // Override default of tls.VersionTLS10
MaxVersion: tls.VersionTLS12, // Same as default in golang <= 1.6
})
errChan := make(chan error, 1)
go func() {
errChan <- conn.Handshake()
}()
select {
case <-time.After(timeout):
return nil, nil, errors.New("boulder/grpc/creds: TLS handshake timed out")
case err := <-errChan:
if err != nil {
_ = rawConn.Close()
return nil, nil, fmt.Errorf("boulder/grpc/creds: TLS handshake failed: %s", err)
}
return conn, nil, nil
}
}
// ServerHandshake performs the TLS handshake for a server <- client connection
func (tc *transportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return nil, nil, fmt.Errorf("boulder/grpc/creds: Server-side handshakes are not implemented")
}
// Info returns information about the transport protocol used
func (tc *transportCredentials) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{
SecurityProtocol: "tls",
SecurityVersion: "1.2", // We *only* support TLS 1.2
}
}
// GetRequestMetadata returns nil, nil since TLS credentials do not have metadata.
func (tc *transportCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return nil, nil
}
// RequireTransportSecurity always returns true because TLS is transport security
func (tc *transportCredentials) RequireTransportSecurity() bool {
return true
}

103
grpc/creds/creds_test.go Normal file
View File

@ -0,0 +1,103 @@
package creds
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"net/http/httptest"
"testing"
"time"
"github.com/letsencrypt/boulder/test"
)
func TestTransportCredentials(t *testing.T) {
priv, err := rsa.GenerateKey(rand.Reader, 1024)
test.AssertNotError(t, err, "rsa.GenerateKey failed")
temp := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "A",
},
NotBefore: time.Unix(1000, 0),
NotAfter: time.Now().AddDate(1, 0, 0),
BasicConstraintsValid: true,
IsCA: true,
}
derA, err := x509.CreateCertificate(rand.Reader, temp, temp, priv.Public(), priv)
test.AssertNotError(t, err, "x509.CreateCertificate failed")
certA, err := x509.ParseCertificate(derA)
test.AssertNotError(t, err, "x509.ParserCertificate failed")
temp.Subject.CommonName = "B"
derB, err := x509.CreateCertificate(rand.Reader, temp, temp, priv.Public(), priv)
test.AssertNotError(t, err, "x509.CreateCertificate failed")
certB, err := x509.ParseCertificate(derB)
test.AssertNotError(t, err, "x509.ParserCertificate failed")
roots := x509.NewCertPool()
roots.AddCert(certA)
roots.AddCert(certB)
serverA := httptest.NewUnstartedServer(nil)
serverA.TLS = &tls.Config{Certificates: []tls.Certificate{{Certificate: [][]byte{derA}, PrivateKey: priv}}}
serverB := httptest.NewUnstartedServer(nil)
serverB.TLS = &tls.Config{Certificates: []tls.Certificate{{Certificate: [][]byte{derB}, PrivateKey: priv}}}
tc := New(roots, nil)
serverA.StartTLS()
defer serverA.Close()
addrA := serverA.Listener.Addr().String()
rawConnA, err := net.Dial("tcp", addrA)
test.AssertNotError(t, err, "net.Dial failed")
defer func() {
_ = rawConnA.Close()
}()
conn, _, err := tc.ClientHandshake("A:2020", rawConnA, time.Second)
test.AssertNotError(t, err, "tc.ClientHandshake failed")
test.Assert(t, conn != nil, "tc.ClientHandshake returned a nil net.Conn")
serverB.StartTLS()
defer serverB.Close()
addrB := serverB.Listener.Addr().String()
rawConnB, err := net.Dial("tcp", addrB)
test.AssertNotError(t, err, "net.Dial failed")
defer func() {
_ = rawConnB.Close()
}()
conn, _, err = tc.ClientHandshake("B:3030", rawConnB, time.Second)
test.AssertNotError(t, err, "tc.ClientHandshake failed")
test.Assert(t, conn != nil, "tc.ClientHandshake returned a nil net.Conn")
// Test timeout
ln, err := net.Listen("tcp", "127.0.0.1:0")
test.AssertNotError(t, err, "net.Listen failed")
defer func() {
_ = ln.Close()
}()
addrC := ln.Addr().String()
go func() {
for {
_, err := ln.Accept()
test.AssertNotError(t, err, "ln.Accept failed")
time.Sleep(time.Second)
}
}()
rawConnC, err := net.Dial("tcp", addrC)
test.AssertNotError(t, err, "net.Dial failed")
defer func() {
_ = rawConnB.Close()
}()
conn, _, err = tc.ClientHandshake("A:2020", rawConnC, time.Millisecond)
test.AssertError(t, err, "tc.ClientHandshake didn't timeout")
test.AssertEquals(t, err.Error(), "boulder/grpc/creds: TLS handshake timed out")
test.Assert(t, conn == nil, "tc.ClientHandshake returned a non-nil net.Conn on failure")
}

View File

@ -12,17 +12,21 @@ import (
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/cmd"
bcreds "github.com/letsencrypt/boulder/grpc/creds"
) )
// CodedError is a alias required to appease go vet // CodedError is a alias required to appease go vet
var CodedError = grpc.Errorf var CodedError = grpc.Errorf
// ClientSetup loads various TLS certificates and creates a // ClientSetup loads various TLS certificates and creates a
// gRPC TransportAuthenticator that presents the client certificate // gRPC TransportCredentials that presents the client certificate
// and validates the certificate presented by the server is for a // and validates the certificate presented by the server is for a
// specific hostname and issued by the provided issuer certificate // specific hostname and issued by the provided issuer certificate
// thens dials and returns a grpc.ClientConn to the remote service. // thens dials and returns a grpc.ClientConn to the remote service.
func ClientSetup(c *cmd.GRPCClientConfig) (*grpc.ClientConn, error) { func ClientSetup(c *cmd.GRPCClientConfig) (*grpc.ClientConn, error) {
if len(c.ServerAddresses) == 0 {
return nil, fmt.Errorf("boulder/grpc: ServerAddresses is empty")
}
serverIssuerBytes, err := ioutil.ReadFile(c.ServerIssuerPath) serverIssuerBytes, err := ioutil.ReadFile(c.ServerIssuerPath)
if err != nil { if err != nil {
return nil, err return nil, err
@ -35,15 +39,11 @@ func ClientSetup(c *cmd.GRPCClientConfig) (*grpc.ClientConn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
host, _, err := net.SplitHostPort(c.ServerAddress) return grpc.Dial(
if err != nil { "", // Since our staticResolver provides addresses we don't need to pass an address here
return nil, err grpc.WithTransportCredentials(bcreds.New(rootCAs, []tls.Certificate{clientCert})),
} grpc.WithBalancer(grpc.RoundRobin(newStaticResolver(c.ServerAddresses))),
return grpc.Dial(c.ServerAddress, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ )
ServerName: host,
RootCAs: rootCAs,
Certificates: []tls.Certificate{clientCert},
})))
} }
// NewServer loads various TLS certificates and creates a // NewServer loads various TLS certificates and creates a

View File

@ -137,7 +137,7 @@
}, },
"maxConcurrentRPCServerRequests": 16, "maxConcurrentRPCServerRequests": 16,
"publisherService": { "publisherService": {
"serverAddress": "boulder:9091", "serverAddresses": ["boulder:9091"],
"serverIssuerPath": "test/grpc-creds/ca.pem", "serverIssuerPath": "test/grpc-creds/ca.pem",
"clientCertificatePath": "test/grpc-creds/client.pem", "clientCertificatePath": "test/grpc-creds/client.pem",
"clientKeyPath": "test/grpc-creds/key.pem", "clientKeyPath": "test/grpc-creds/key.pem",
@ -173,7 +173,7 @@
"doNotForceCN": true, "doNotForceCN": true,
"reuseValidAuthz": true, "reuseValidAuthz": true,
"vaService": { "vaService": {
"serverAddress": "boulder:9092", "serverAddresses": ["boulder:9092"],
"serverIssuerPath": "test/grpc-creds/ca.pem", "serverIssuerPath": "test/grpc-creds/ca.pem",
"clientCertificatePath": "test/grpc-creds/client.pem", "clientCertificatePath": "test/grpc-creds/client.pem",
"clientKeyPath": "test/grpc-creds/key.pem", "clientKeyPath": "test/grpc-creds/key.pem",
@ -224,7 +224,7 @@
"dnsTries": 3, "dnsTries": 3,
"issuerDomain": "happy-hacker-ca.invalid", "issuerDomain": "happy-hacker-ca.invalid",
"caaService": { "caaService": {
"serverAddress": "boulder:9090", "serverAddresses": ["boulder:9090"],
"serverIssuerPath": "test/grpc-creds/ca.pem", "serverIssuerPath": "test/grpc-creds/ca.pem",
"clientCertificatePath": "test/grpc-creds/client.pem", "clientCertificatePath": "test/grpc-creds/client.pem",
"clientKeyPath": "test/grpc-creds/key.pem" "clientKeyPath": "test/grpc-creds/key.pem"
@ -297,7 +297,7 @@
"signFailureBackoffMax": "30m", "signFailureBackoffMax": "30m",
"debugAddr": "localhost:8006", "debugAddr": "localhost:8006",
"publisher": { "publisher": {
"serverAddress": "boulder:9091", "serverAddresses": ["boulder:9091"],
"serverIssuerPath": "test/grpc-creds/ca.pem", "serverIssuerPath": "test/grpc-creds/ca.pem",
"clientCertificatePath": "test/grpc-creds/client.pem", "clientCertificatePath": "test/grpc-creds/client.pem",
"clientKeyPath": "test/grpc-creds/key.pem", "clientKeyPath": "test/grpc-creds/key.pem",

View File

@ -21,8 +21,9 @@ proto:
exit 1; \ exit 1; \
fi fi
go get -u -v github.com/golang/protobuf/protoc-gen-go go get -u -v github.com/golang/protobuf/protoc-gen-go
for file in $$(git ls-files '*.proto'); do \ # use $$dir as the root for all proto files in the same directory
protoc -I $$(dirname $$file) --go_out=plugins=grpc:$$(dirname $$file) $$file; \ for dir in $$(git ls-files '*.proto' | xargs -n1 dirname | uniq); do \
protoc -I $$dir --go_out=plugins=grpc:$$dir $$dir/*.proto; \
done done
test: testdeps test: testdeps

View File

@ -1,22 +1,22 @@
Additional IP Rights Grant (Patents) Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by "This implementation" means the copyrightable works distributed by
Google as part of the GRPC project. Google as part of the gRPC project.
Google hereby grants to You a perpetual, worldwide, non-exclusive, Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section) no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import, patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this transfer and otherwise run, modify and propagate the contents of this
implementation of GRPC, where such license applies only to those patent implementation of gRPC, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this the future, licensable by Google that are necessarily infringed by this
implementation of GRPC. This grant does not include claims that would be implementation of gRPC. This grant does not include claims that would be
infringed only as a consequence of further modification of this infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of GRPC or any code incorporated within this that this implementation of gRPC or any code incorporated within this
implementation of GRPC constitutes direct or contributory patent implementation of gRPC constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of GRPC rights granted to you under this License for this implementation of gRPC
shall terminate as of the date such litigation is filed. shall terminate as of the date such litigation is filed.

View File

@ -19,7 +19,7 @@ var (
// backoffStrategy defines the methodology for backing off after a grpc // backoffStrategy defines the methodology for backing off after a grpc
// connection failure. // connection failure.
// //
// This is unexported until the GRPC project decides whether or not to allow // This is unexported until the gRPC project decides whether or not to allow
// alternative backoff strategies. Once a decision is made, this type and its // alternative backoff strategies. Once a decision is made, this type and its
// method may be exported. // method may be exported.
type backoffStrategy interface { type backoffStrategy interface {
@ -28,14 +28,14 @@ type backoffStrategy interface {
backoff(retries int) time.Duration backoff(retries int) time.Duration
} }
// BackoffConfig defines the parameters for the default GRPC backoff strategy. // BackoffConfig defines the parameters for the default gRPC backoff strategy.
type BackoffConfig struct { type BackoffConfig struct {
// MaxDelay is the upper bound of backoff delay. // MaxDelay is the upper bound of backoff delay.
MaxDelay time.Duration MaxDelay time.Duration
// TODO(stevvooe): The following fields are not exported, as allowing // TODO(stevvooe): The following fields are not exported, as allowing
// changes would violate the current GRPC specification for backoff. If // changes would violate the current gRPC specification for backoff. If
// GRPC decides to allow more interesting backoff strategies, these fields // gRPC decides to allow more interesting backoff strategies, these fields
// may be opened up in the future. // may be opened up in the future.
// baseDelay is the amount of time to wait before retrying after the first // baseDelay is the amount of time to wait before retrying after the first

340
vendor/google.golang.org/grpc/balancer.go generated vendored Normal file
View File

@ -0,0 +1,340 @@
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"fmt"
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/transport"
)
// Address represents a server the client connects to.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
type Address struct {
// Addr is the server address on which a connection will be established.
Addr string
// Metadata is the information associated with Addr, which may be used
// to make load balancing decision.
Metadata interface{}
}
// BalancerGetOptions configures a Get call.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
type BalancerGetOptions struct {
// BlockingWait specifies whether Get should block when there is no
// connected address.
BlockingWait bool
}
// Balancer chooses network addresses for RPCs.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
type Balancer interface {
// Start does the initialization work to bootstrap a Balancer. For example,
// this function may start the name resolution and watch the updates. It will
// be called when dialing.
Start(target string) error
// Up informs the Balancer that gRPC has a connection to the server at
// addr. It returns down which is called once the connection to addr gets
// lost or closed.
// TODO: It is not clear how to construct and take advantage the meaningful error
// parameter for down. Need realistic demands to guide.
Up(addr Address) (down func(error))
// Get gets the address of a server for the RPC corresponding to ctx.
// i) If it returns a connected address, gRPC internals issues the RPC on the
// connection to this address;
// ii) If it returns an address on which the connection is under construction
// (initiated by Notify(...)) but not connected, gRPC internals
// * fails RPC if the RPC is fail-fast and connection is in the TransientFailure or
// Shutdown state;
// or
// * issues RPC on the connection otherwise.
// iii) If it returns an address on which the connection does not exist, gRPC
// internals treats it as an error and will fail the corresponding RPC.
//
// Therefore, the following is the recommended rule when writing a custom Balancer.
// If opts.BlockingWait is true, it should return a connected address or
// block if there is no connected address. It should respect the timeout or
// cancellation of ctx when blocking. If opts.BlockingWait is false (for fail-fast
// RPCs), it should return an address it has notified via Notify(...) immediately
// instead of blocking.
//
// The function returns put which is called once the rpc has completed or failed.
// put can collect and report RPC stats to a remote load balancer. gRPC internals
// will try to call this again if err is non-nil (unless err is ErrClientConnClosing).
//
// TODO: Add other non-recoverable errors?
Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
// Notify returns a channel that is used by gRPC internals to watch the addresses
// gRPC needs to connect. The addresses might be from a name resolver or remote
// load balancer. gRPC internals will compare it with the existing connected
// addresses. If the address Balancer notified is not in the existing connected
// addresses, gRPC starts to connect the address. If an address in the existing
// connected addresses is not in the notification list, the corresponding connection
// is shutdown gracefully. Otherwise, there are no operations to take. Note that
// the Address slice must be the full list of the Addresses which should be connected.
// It is NOT delta.
Notify() <-chan []Address
// Close shuts down the balancer.
Close() error
}
// downErr implements net.Error. It is constructed by gRPC internals and passed to the down
// call of Balancer.
type downErr struct {
timeout bool
temporary bool
desc string
}
func (e downErr) Error() string { return e.desc }
func (e downErr) Timeout() bool { return e.timeout }
func (e downErr) Temporary() bool { return e.temporary }
func downErrorf(timeout, temporary bool, format string, a ...interface{}) downErr {
return downErr{
timeout: timeout,
temporary: temporary,
desc: fmt.Sprintf(format, a...),
}
}
// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
// the name resolution updates and updates the addresses available correspondingly.
func RoundRobin(r naming.Resolver) Balancer {
return &roundRobin{r: r}
}
type roundRobin struct {
r naming.Resolver
w naming.Watcher
open []Address // all the addresses the client should potentially connect
mu sync.Mutex
addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to.
connected []Address // all the connected addresses
next int // index of the next address to return for Get()
waitCh chan struct{} // the channel to block when there is no connected address available
done bool // The Balancer is closed.
}
func (rr *roundRobin) watchAddrUpdates() error {
updates, err := rr.w.Next()
if err != nil {
grpclog.Println("grpc: the naming watcher stops working due to %v.", err)
return err
}
rr.mu.Lock()
defer rr.mu.Unlock()
for _, update := range updates {
addr := Address{
Addr: update.Addr,
}
switch update.Op {
case naming.Add:
var exist bool
for _, v := range rr.open {
if addr == v {
exist = true
grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
break
}
}
if exist {
continue
}
rr.open = append(rr.open, addr)
case naming.Delete:
for i, v := range rr.open {
if v == addr {
copy(rr.open[i:], rr.open[i+1:])
rr.open = rr.open[:len(rr.open)-1]
break
}
}
default:
grpclog.Println("Unknown update.Op ", update.Op)
}
}
// Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals gets notified.
open := make([]Address, len(rr.open), len(rr.open))
copy(open, rr.open)
if rr.done {
return ErrClientConnClosing
}
rr.addrCh <- open
return nil
}
func (rr *roundRobin) Start(target string) error {
if rr.r == nil {
// If there is no name resolver installed, it is not needed to
// do name resolution. In this case, rr.addrCh stays nil.
return nil
}
w, err := rr.r.Resolve(target)
if err != nil {
return err
}
rr.w = w
rr.addrCh = make(chan []Address)
go func() {
for {
if err := rr.watchAddrUpdates(); err != nil {
return
}
}
}()
return nil
}
// Up appends addr to the end of rr.connected and sends notification if there
// are pending Get() calls.
func (rr *roundRobin) Up(addr Address) func(error) {
rr.mu.Lock()
defer rr.mu.Unlock()
for _, a := range rr.connected {
if a == addr {
return nil
}
}
rr.connected = append(rr.connected, addr)
if len(rr.connected) == 1 {
// addr is only one available. Notify the Get() callers who are blocking.
if rr.waitCh != nil {
close(rr.waitCh)
rr.waitCh = nil
}
}
return func(err error) {
rr.down(addr, err)
}
}
// down removes addr from rr.connected and moves the remaining addrs forward.
func (rr *roundRobin) down(addr Address, err error) {
rr.mu.Lock()
defer rr.mu.Unlock()
for i, a := range rr.connected {
if a == addr {
copy(rr.connected[i:], rr.connected[i+1:])
rr.connected = rr.connected[:len(rr.connected)-1]
return
}
}
}
// Get returns the next addr in the rotation.
func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
var ch chan struct{}
rr.mu.Lock()
if rr.done {
rr.mu.Unlock()
err = ErrClientConnClosing
return
}
if rr.next >= len(rr.connected) {
rr.next = 0
}
if len(rr.connected) > 0 {
addr = rr.connected[rr.next]
rr.next++
rr.mu.Unlock()
return
}
// There is no address available. Wait on rr.waitCh.
// TODO(zhaoq): Handle the case when opts.BlockingWait is false.
if rr.waitCh == nil {
ch = make(chan struct{})
rr.waitCh = ch
} else {
ch = rr.waitCh
}
rr.mu.Unlock()
for {
select {
case <-ctx.Done():
err = transport.ContextErr(ctx.Err())
return
case <-ch:
rr.mu.Lock()
if rr.done {
rr.mu.Unlock()
err = ErrClientConnClosing
return
}
if len(rr.connected) == 0 {
// The newly added addr got removed by Down() again.
if rr.waitCh == nil {
ch = make(chan struct{})
rr.waitCh = ch
} else {
ch = rr.waitCh
}
rr.mu.Unlock()
continue
}
if rr.next >= len(rr.connected) {
rr.next = 0
}
addr = rr.connected[rr.next]
rr.next++
rr.mu.Unlock()
return
}
}
}
func (rr *roundRobin) Notify() <-chan []Address {
return rr.addrCh
}
func (rr *roundRobin) Close() error {
rr.mu.Lock()
defer rr.mu.Unlock()
rr.done = true
if rr.w != nil {
rr.w.Close()
}
if rr.waitCh != nil {
close(rr.waitCh)
rr.waitCh = nil
}
if rr.addrCh != nil {
close(rr.addrCh)
}
return nil
}

View File

@ -132,19 +132,16 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
Last: true, Last: true,
Delay: false, Delay: false,
} }
var (
lastErr error // record the error that happened
)
for { for {
var ( var (
err error err error
t transport.ClientTransport t transport.ClientTransport
stream *transport.Stream stream *transport.Stream
// Record the put handler from Balancer.Get(...). It is called once the
// RPC has completed or failed.
put func()
) )
// TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs. // TODO(zhaoq): Need a formal spec of fail-fast.
if lastErr != nil && c.failFast {
return toRPCErr(lastErr)
}
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
@ -152,39 +149,66 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if cc.dopts.cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
t, err = cc.dopts.picker.Pick(ctx) gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
}
t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
if lastErr != nil { // TODO(zhaoq): Probably revisit the error handling.
// This was a retry; return the error from the last attempt. if err == ErrClientConnClosing {
return toRPCErr(lastErr) return Errorf(codes.FailedPrecondition, "%v", err)
} }
return toRPCErr(err) if _, ok := err.(transport.StreamError); ok {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast {
return toRPCErr(err)
}
}
// All the remaining cases are treated as retryable.
continue
} }
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
} }
stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts) stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
if err != nil { if err != nil {
if _, ok := err.(transport.ConnectionError); ok { if put != nil {
lastErr = err put()
continue put = nil
} }
if lastErr != nil { if _, ok := err.(transport.ConnectionError); ok {
return toRPCErr(lastErr) if c.failFast {
return toRPCErr(err)
}
continue
} }
return toRPCErr(err) return toRPCErr(err)
} }
// Receive the response // Receive the response
lastErr = recvResponse(cc.dopts, t, &c, stream, reply) err = recvResponse(cc.dopts, t, &c, stream, reply)
if _, ok := lastErr.(transport.ConnectionError); ok { if err != nil {
continue if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast {
return toRPCErr(err)
}
continue
}
t.CloseStream(stream, err)
return toRPCErr(err)
} }
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
} }
t.CloseStream(stream, lastErr) t.CloseStream(stream, nil)
if lastErr != nil { if put != nil {
return toRPCErr(lastErr) put()
put = nil
} }
return Errorf(stream.StatusCode(), "%s", stream.StatusDesc()) return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
} }

View File

@ -43,28 +43,38 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
) )
var ( var (
// ErrUnspecTarget indicates that the target address is unspecified. // ErrClientConnClosing indicates that the operation is illegal because
ErrUnspecTarget = errors.New("grpc: target is unspecified") // the ClientConn is closing.
// ErrNoTransportSecurity indicates that there is no transport security ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
// underlying connections within the specified timeout.
ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
// errNoTransportSecurity indicates that there is no transport security
// being set for ClientConn. Users should either set one or explicitly // being set for ClientConn. Users should either set one or explicitly
// call WithInsecure DialOption to disable security. // call WithInsecure DialOption to disable security.
ErrNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)") errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
// ErrCredentialsMisuse indicates that users want to transmit security information // errTransportCredentialsMissing indicates that users want to transmit security
// (e.g., oauth2 token) which requires secure connection on an insecure // information (e.g., oauth2 token) which requires secure connection on an insecure
// connection. // connection.
ErrCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)") errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)")
// ErrClientConnClosing indicates that the operation is illegal because // errCredentialsConflict indicates that grpc.WithTransportCredentials()
// the session is closing. // and grpc.WithInsecure() are both called for a connection.
ErrClientConnClosing = errors.New("grpc: the client connection is closing") errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
// ErrClientConnTimeout indicates that the connection could not be // errNetworkIP indicates that the connection is down due to some network I/O error.
// established or re-established within the specified timeout. errNetworkIO = errors.New("grpc: failed with network I/O error")
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing")
errNoAddr = errors.New("grpc: there is no address available to dial")
// minimum time to give a connection to complete // minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second minConnectTimeout = 20 * time.Second
) )
@ -76,9 +86,10 @@ type dialOptions struct {
cp Compressor cp Compressor
dc Decompressor dc Decompressor
bs backoffStrategy bs backoffStrategy
picker Picker balancer Balancer
block bool block bool
insecure bool insecure bool
timeout time.Duration
copts transport.ConnectOptions copts transport.ConnectOptions
} }
@ -108,10 +119,10 @@ func WithDecompressor(dc Decompressor) DialOption {
} }
} }
// WithPicker returns a DialOption which sets a picker for connection selection. // WithBalancer returns a DialOption which sets a load balancer.
func WithPicker(p Picker) DialOption { func WithBalancer(b Balancer) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.picker = p o.balancer = b
} }
} }
@ -136,7 +147,7 @@ func WithBackoffConfig(b BackoffConfig) DialOption {
// withBackoff sets the backoff strategy used for retries after a // withBackoff sets the backoff strategy used for retries after a
// failed connection attempt. // failed connection attempt.
// //
// This can be exported if arbitrary backoff strategies are allowed by GRPC. // This can be exported if arbitrary backoff strategies are allowed by gRPC.
func withBackoff(bs backoffStrategy) DialOption { func withBackoff(bs backoffStrategy) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.bs = bs o.bs = bs
@ -162,24 +173,25 @@ func WithInsecure() DialOption {
// WithTransportCredentials returns a DialOption which configures a // WithTransportCredentials returns a DialOption which configures a
// connection level security credentials (e.g., TLS/SSL). // connection level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { func WithTransportCredentials(creds credentials.TransportCredentials) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.AuthOptions = append(o.copts.AuthOptions, creds) o.copts.TransportCredentials = creds
} }
} }
// WithPerRPCCredentials returns a DialOption which sets // WithPerRPCCredentials returns a DialOption which sets
// credentials which will place auth state on each outbound RPC. // credentials which will place auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.Credentials) DialOption { func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.AuthOptions = append(o.copts.AuthOptions, creds) o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds)
} }
} }
// WithTimeout returns a DialOption that configures a timeout for dialing a client connection. // WithTimeout returns a DialOption that configures a timeout for dialing a ClientConn
// initially. This is valid if and only if WithBlock() is present.
func WithTimeout(d time.Duration) DialOption { func WithTimeout(d time.Duration) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.Timeout = d o.timeout = d
} }
} }
@ -201,6 +213,7 @@ func WithUserAgent(s string) DialOption {
func Dial(target string, opts ...DialOption) (*ClientConn, error) { func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
conns: make(map[Address]*addrConn),
} }
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
@ -214,13 +227,53 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc.dopts.bs = DefaultBackoffConfig cc.dopts.bs = DefaultBackoffConfig
} }
if cc.dopts.picker == nil { cc.balancer = cc.dopts.balancer
cc.dopts.picker = &unicastPicker{ if cc.balancer == nil {
target: target, cc.balancer = RoundRobin(nil)
}
if err := cc.balancer.Start(target); err != nil {
return nil, err
}
var (
ok bool
addrs []Address
)
ch := cc.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
addrs = append(addrs, Address{Addr: target})
} else {
addrs, ok = <-ch
if !ok || len(addrs) == 0 {
return nil, errNoAddr
} }
} }
if err := cc.dopts.picker.Init(cc); err != nil { waitC := make(chan error, 1)
return nil, err go func() {
for _, a := range addrs {
if err := cc.newAddrConn(a, false); err != nil {
waitC <- err
return
}
}
close(waitC)
}()
var timeoutCh <-chan time.Time
if cc.dopts.timeout > 0 {
timeoutCh = time.After(cc.dopts.timeout)
}
select {
case err := <-waitC:
if err != nil {
cc.Close()
return nil, err
}
case <-timeoutCh:
cc.Close()
return nil, ErrClientConnTimeout
}
if ok {
go cc.lbWatcher()
} }
colonPos := strings.LastIndex(target, ":") colonPos := strings.LastIndex(target, ":")
if colonPos == -1 { if colonPos == -1 {
@ -263,325 +316,358 @@ func (s ConnectivityState) String() string {
} }
} }
// ClientConn represents a client connection to an RPC service. // ClientConn represents a client connection to an RPC server.
type ClientConn struct { type ClientConn struct {
target string target string
balancer Balancer
authority string authority string
dopts dialOptions dopts dialOptions
mu sync.RWMutex
conns map[Address]*addrConn
} }
// State returns the connectivity state of cc. func (cc *ClientConn) lbWatcher() {
// This is EXPERIMENTAL API. for addrs := range cc.balancer.Notify() {
func (cc *ClientConn) State() (ConnectivityState, error) { var (
return cc.dopts.picker.State() add []Address // Addresses need to setup connections.
del []*addrConn // Connections need to tear down.
)
cc.mu.Lock()
for _, a := range addrs {
if _, ok := cc.conns[a]; !ok {
add = append(add, a)
}
}
for k, c := range cc.conns {
var keep bool
for _, a := range addrs {
if k == a {
keep = true
break
}
}
if !keep {
del = append(del, c)
}
}
cc.mu.Unlock()
for _, a := range add {
cc.newAddrConn(a, true)
}
for _, c := range del {
c.tearDown(errConnDrain)
}
}
} }
// WaitForStateChange blocks until the state changes to something other than the sourceState. func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
// It returns the new state or error. ac := &addrConn{
// This is EXPERIMENTAL API. cc: cc,
func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) { addr: addr,
return cc.dopts.picker.WaitForStateChange(ctx, sourceState) dopts: cc.dopts,
shutdownChan: make(chan struct{}),
}
if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
}
if !ac.dopts.insecure {
if ac.dopts.copts.TransportCredentials == nil {
return errNoTransportSecurity
}
} else {
if ac.dopts.copts.TransportCredentials != nil {
return errCredentialsConflict
}
for _, cd := range ac.dopts.copts.PerRPCCredentials {
if cd.RequireTransportSecurity() {
return errTransportCredentialsMissing
}
}
}
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
ac.cc.mu.Lock()
if ac.cc.conns == nil {
ac.cc.mu.Unlock()
return ErrClientConnClosing
}
stale := ac.cc.conns[ac.addr]
ac.cc.conns[ac.addr] = ac
ac.cc.mu.Unlock()
if stale != nil {
// There is an addrConn alive on ac.addr already. This could be due to
// i) stale's Close is undergoing;
// ii) a buggy Balancer notifies duplicated Addresses.
stale.tearDown(errConnDrain)
}
ac.stateCV = sync.NewCond(&ac.mu)
// skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil {
ac.tearDown(err)
return err
}
// Start to monitor the error status of transport.
go ac.transportMonitor()
} else {
// Start a goroutine connecting to the server asynchronously.
go func() {
if err := ac.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
ac.tearDown(err)
return
}
ac.transportMonitor()
}()
}
return nil
} }
// Close starts to tear down the ClientConn. func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
// TODO(zhaoq): Implement fail-fast logic.
addr, put, err := cc.balancer.Get(ctx, opts)
if err != nil {
return nil, nil, err
}
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, ErrClientConnClosing
}
ac, ok := cc.conns[addr]
cc.mu.RUnlock()
if !ok {
if put != nil {
put()
}
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
}
t, err := ac.wait(ctx)
if err != nil {
if put != nil {
put()
}
return nil, nil, err
}
return t, put, nil
}
// Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error { func (cc *ClientConn) Close() error {
return cc.dopts.picker.Close() cc.mu.Lock()
if cc.conns == nil {
cc.mu.Unlock()
return ErrClientConnClosing
}
conns := cc.conns
cc.conns = nil
cc.mu.Unlock()
cc.balancer.Close()
for _, ac := range conns {
ac.tearDown(ErrClientConnClosing)
}
return nil
} }
// Conn is a client connection to a single destination. // addrConn is a network connection to a given address.
type Conn struct { type addrConn struct {
target string cc *ClientConn
addr Address
dopts dialOptions dopts dialOptions
resetChan chan int
shutdownChan chan struct{} shutdownChan chan struct{}
events trace.EventLog events trace.EventLog
mu sync.Mutex mu sync.Mutex
state ConnectivityState state ConnectivityState
stateCV *sync.Cond stateCV *sync.Cond
down func(error) // the handler called when a connection is down.
// ready is closed and becomes nil when a new transport is up or failed // ready is closed and becomes nil when a new transport is up or failed
// due to timeout. // due to timeout.
ready chan struct{} ready chan struct{}
transport transport.ClientTransport transport transport.ClientTransport
} }
// NewConn creates a Conn. // printf records an event in ac's event log, unless ac has been closed.
func NewConn(cc *ClientConn) (*Conn, error) { // REQUIRES ac.mu is held.
if cc.target == "" { func (ac *addrConn) printf(format string, a ...interface{}) {
return nil, ErrUnspecTarget if ac.events != nil {
} ac.events.Printf(format, a...)
c := &Conn{
target: cc.target,
dopts: cc.dopts,
resetChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
}
if EnableTracing {
c.events = trace.NewEventLog("grpc.ClientConn", c.target)
}
if !c.dopts.insecure {
var ok bool
for _, cd := range c.dopts.copts.AuthOptions {
if _, ok = cd.(credentials.TransportAuthenticator); ok {
break
}
}
if !ok {
return nil, ErrNoTransportSecurity
}
} else {
for _, cd := range c.dopts.copts.AuthOptions {
if cd.RequireTransportSecurity() {
return nil, ErrCredentialsMisuse
}
}
}
c.stateCV = sync.NewCond(&c.mu)
if c.dopts.block {
if err := c.resetTransport(false); err != nil {
c.Close()
return nil, err
}
// Start to monitor the error status of transport.
go c.transportMonitor()
} else {
// Start a goroutine connecting to the server asynchronously.
go func() {
if err := c.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", c.target, err)
c.Close()
return
}
c.transportMonitor()
}()
}
return c, nil
}
// printf records an event in cc's event log, unless cc has been closed.
// REQUIRES cc.mu is held.
func (cc *Conn) printf(format string, a ...interface{}) {
if cc.events != nil {
cc.events.Printf(format, a...)
} }
} }
// errorf records an error in cc's event log, unless cc has been closed. // errorf records an error in ac's event log, unless ac has been closed.
// REQUIRES cc.mu is held. // REQUIRES ac.mu is held.
func (cc *Conn) errorf(format string, a ...interface{}) { func (ac *addrConn) errorf(format string, a ...interface{}) {
if cc.events != nil { if ac.events != nil {
cc.events.Errorf(format, a...) ac.events.Errorf(format, a...)
} }
} }
// State returns the connectivity state of the Conn // getState returns the connectivity state of the Conn
func (cc *Conn) State() ConnectivityState { func (ac *addrConn) getState() ConnectivityState {
cc.mu.Lock() ac.mu.Lock()
defer cc.mu.Unlock() defer ac.mu.Unlock()
return cc.state return ac.state
} }
// WaitForStateChange blocks until the state changes to something other than the sourceState. // waitForStateChange blocks until the state changes to something other than the sourceState.
func (cc *Conn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) { func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
cc.mu.Lock() ac.mu.Lock()
defer cc.mu.Unlock() defer ac.mu.Unlock()
if sourceState != cc.state { if sourceState != ac.state {
return cc.state, nil return ac.state, nil
} }
done := make(chan struct{}) done := make(chan struct{})
var err error var err error
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
cc.mu.Lock() ac.mu.Lock()
err = ctx.Err() err = ctx.Err()
cc.stateCV.Broadcast() ac.stateCV.Broadcast()
cc.mu.Unlock() ac.mu.Unlock()
case <-done: case <-done:
} }
}() }()
defer close(done) defer close(done)
for sourceState == cc.state { for sourceState == ac.state {
cc.stateCV.Wait() ac.stateCV.Wait()
if err != nil { if err != nil {
return cc.state, err return ac.state, err
} }
} }
return cc.state, nil return ac.state, nil
} }
// NotifyReset tries to signal the underlying transport needs to be reset due to func (ac *addrConn) resetTransport(closeTransport bool) error {
// for example a name resolution change in flight.
func (cc *Conn) NotifyReset() {
select {
case cc.resetChan <- 0:
default:
}
}
func (cc *Conn) resetTransport(closeTransport bool) error {
var retries int var retries int
start := time.Now()
for { for {
cc.mu.Lock() ac.mu.Lock()
cc.printf("connecting") ac.printf("connecting")
if cc.state == Shutdown { if ac.state == Shutdown {
// cc.Close() has been invoked. // ac.tearDown(...) has been invoked.
cc.mu.Unlock() ac.mu.Unlock()
return ErrClientConnClosing return errConnClosing
} }
cc.state = Connecting if ac.down != nil {
cc.stateCV.Broadcast() ac.down(downErrorf(false, true, "%v", errNetworkIO))
cc.mu.Unlock() ac.down = nil
if closeTransport {
cc.transport.Close()
} }
// Adjust timeout for the current try. ac.state = Connecting
copts := cc.dopts.copts ac.stateCV.Broadcast()
if copts.Timeout < 0 { t := ac.transport
cc.Close() ac.mu.Unlock()
return ErrClientConnTimeout if closeTransport && t != nil {
t.Close()
} }
if copts.Timeout > 0 { sleepTime := ac.dopts.bs.backoff(retries)
copts.Timeout -= time.Since(start) ac.dopts.copts.Timeout = sleepTime
if copts.Timeout <= 0 { if sleepTime < minConnectTimeout {
cc.Close() ac.dopts.copts.Timeout = minConnectTimeout
return ErrClientConnTimeout
}
}
sleepTime := cc.dopts.bs.backoff(retries)
timeout := sleepTime
if timeout < minConnectTimeout {
timeout = minConnectTimeout
}
if copts.Timeout == 0 || copts.Timeout > timeout {
copts.Timeout = timeout
} }
connectTime := time.Now() connectTime := time.Now()
addr, err := cc.dopts.picker.PickAddr() newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts)
var newTransport transport.ClientTransport
if err == nil {
newTransport, err = transport.NewClientTransport(addr, &copts)
}
if err != nil { if err != nil {
cc.mu.Lock() ac.mu.Lock()
if cc.state == Shutdown { if ac.state == Shutdown {
// cc.Close() has been invoked. // ac.tearDown(...) has been invoked.
cc.mu.Unlock() ac.mu.Unlock()
return ErrClientConnClosing return errConnClosing
} }
cc.errorf("transient failure: %v", err) ac.errorf("transient failure: %v", err)
cc.state = TransientFailure ac.state = TransientFailure
cc.stateCV.Broadcast() ac.stateCV.Broadcast()
if cc.ready != nil { if ac.ready != nil {
close(cc.ready) close(ac.ready)
cc.ready = nil ac.ready = nil
} }
cc.mu.Unlock() ac.mu.Unlock()
sleepTime -= time.Since(connectTime) sleepTime -= time.Since(connectTime)
if sleepTime < 0 { if sleepTime < 0 {
sleepTime = 0 sleepTime = 0
} }
// Fail early before falling into sleep.
if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) {
cc.mu.Lock()
cc.errorf("connection timeout")
cc.mu.Unlock()
cc.Close()
return ErrClientConnTimeout
}
closeTransport = false closeTransport = false
time.Sleep(sleepTime) select {
case <-time.After(sleepTime):
case <-ac.shutdownChan:
}
retries++ retries++
grpclog.Printf("grpc: Conn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target) grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
continue continue
} }
cc.mu.Lock() ac.mu.Lock()
cc.printf("ready") ac.printf("ready")
if cc.state == Shutdown { if ac.state == Shutdown {
// cc.Close() has been invoked. // ac.tearDown(...) has been invoked.
cc.mu.Unlock() ac.mu.Unlock()
newTransport.Close() newTransport.Close()
return ErrClientConnClosing return errConnClosing
} }
cc.state = Ready ac.state = Ready
cc.stateCV.Broadcast() ac.stateCV.Broadcast()
cc.transport = newTransport ac.transport = newTransport
if cc.ready != nil { if ac.ready != nil {
close(cc.ready) close(ac.ready)
cc.ready = nil ac.ready = nil
} }
cc.mu.Unlock() ac.down = ac.cc.balancer.Up(ac.addr)
ac.mu.Unlock()
return nil return nil
} }
} }
func (cc *Conn) reconnect() bool {
cc.mu.Lock()
if cc.state == Shutdown {
// cc.Close() has been invoked.
cc.mu.Unlock()
return false
}
cc.state = TransientFailure
cc.stateCV.Broadcast()
cc.mu.Unlock()
if err := cc.resetTransport(true); err != nil {
// The ClientConn is closing.
cc.mu.Lock()
cc.printf("transport exiting: %v", err)
cc.mu.Unlock()
grpclog.Printf("grpc: Conn.transportMonitor exits due to: %v", err)
return false
}
return true
}
// Run in a goroutine to track the error in transport and create the // Run in a goroutine to track the error in transport and create the
// new transport if an error happens. It returns when the channel is closing. // new transport if an error happens. It returns when the channel is closing.
func (cc *Conn) transportMonitor() { func (ac *addrConn) transportMonitor() {
for { for {
ac.mu.Lock()
t := ac.transport
ac.mu.Unlock()
select { select {
// shutdownChan is needed to detect the teardown when // shutdownChan is needed to detect the teardown when
// the ClientConn is idle (i.e., no RPC in flight). // the addrConn is idle (i.e., no RPC in flight).
case <-cc.shutdownChan: case <-ac.shutdownChan:
return return
case <-cc.resetChan: case <-t.Error():
if !cc.reconnect() { ac.mu.Lock()
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
ac.mu.Unlock()
return return
} }
case <-cc.transport.Error(): ac.state = TransientFailure
if !cc.reconnect() { ac.stateCV.Broadcast()
ac.mu.Unlock()
if err := ac.resetTransport(true); err != nil {
ac.mu.Lock()
ac.printf("transport exiting: %v", err)
ac.mu.Unlock()
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
return return
} }
// Tries to drain reset signal if there is any since it is out-dated.
select {
case <-cc.resetChan:
default:
}
} }
} }
} }
// Wait blocks until i) the new transport is up or ii) ctx is done or iii) cc is closed. // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed.
func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) { func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) {
for { for {
cc.mu.Lock() ac.mu.Lock()
switch { switch {
case cc.state == Shutdown: case ac.state == Shutdown:
cc.mu.Unlock() ac.mu.Unlock()
return nil, ErrClientConnClosing return nil, errConnClosing
case cc.state == Ready: case ac.state == Ready:
ct := cc.transport ct := ac.transport
cc.mu.Unlock() ac.mu.Unlock()
return ct, nil return ct, nil
default: default:
ready := cc.ready ready := ac.ready
if ready == nil { if ready == nil {
ready = make(chan struct{}) ready = make(chan struct{})
cc.ready = ready ac.ready = ready
} }
cc.mu.Unlock() ac.mu.Unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, transport.ContextErr(ctx.Err()) return nil, transport.ContextErr(ctx.Err())
@ -592,32 +678,46 @@ func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) {
} }
} }
// Close starts to tear down the Conn. Returns ErrClientConnClosing if // tearDown starts to tear down the addrConn.
// it has been closed (mostly due to dial time-out).
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many ClientConn's in a // some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop. // tight loop.
func (cc *Conn) Close() error { func (ac *addrConn) tearDown(err error) {
cc.mu.Lock() ac.mu.Lock()
defer cc.mu.Unlock() defer func() {
if cc.state == Shutdown { ac.mu.Unlock()
return ErrClientConnClosing ac.cc.mu.Lock()
if ac.cc.conns != nil {
delete(ac.cc.conns, ac.addr)
}
ac.cc.mu.Unlock()
}()
if ac.state == Shutdown {
return
} }
cc.state = Shutdown ac.state = Shutdown
cc.stateCV.Broadcast() if ac.down != nil {
if cc.events != nil { ac.down(downErrorf(false, false, "%v", err))
cc.events.Finish() ac.down = nil
cc.events = nil
} }
if cc.ready != nil { ac.stateCV.Broadcast()
close(cc.ready) if ac.events != nil {
cc.ready = nil ac.events.Finish()
ac.events = nil
} }
if cc.transport != nil { if ac.ready != nil {
cc.transport.Close() close(ac.ready)
ac.ready = nil
} }
if cc.shutdownChan != nil { if ac.transport != nil {
close(cc.shutdownChan) if err == errConnDrain {
ac.transport.GracefulClose()
} else {
ac.transport.Close()
}
} }
return nil if ac.shutdownChan != nil {
close(ac.shutdownChan)
}
return
} }

View File

@ -54,9 +54,9 @@ var (
alpnProtoStr = []string{"h2"} alpnProtoStr = []string{"h2"}
) )
// Credentials defines the common interface all supported credentials must // PerRPCCredentials defines the common interface for the credentials which need to
// implement. // attach security information to every RPC (e.g., oauth2).
type Credentials interface { type PerRPCCredentials interface {
// GetRequestMetadata gets the current request metadata, refreshing // GetRequestMetadata gets the current request metadata, refreshing
// tokens if required. This should be called by the transport layer on // tokens if required. This should be called by the transport layer on
// each request, and the data should be populated in headers or other // each request, and the data should be populated in headers or other
@ -87,9 +87,9 @@ type AuthInfo interface {
AuthType() string AuthType() string
} }
// TransportAuthenticator defines the common interface for all the live gRPC wire // TransportCredentials defines the common interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL). // protocols and supported transport security protocols (e.g., TLS, SSL).
type TransportAuthenticator interface { type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding // ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated // authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection. // connection and the corresponding auth information about the connection.
@ -98,9 +98,8 @@ type TransportAuthenticator interface {
// the authenticated connection and the corresponding auth information about // the authenticated connection and the corresponding auth information about
// the connection. // the connection.
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportAuthenticator. // Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo Info() ProtocolInfo
Credentials
} }
// TLSInfo contains the auth information for a TLS authenticated connection. // TLSInfo contains the auth information for a TLS authenticated connection.
@ -109,6 +108,7 @@ type TLSInfo struct {
State tls.ConnectionState State tls.ConnectionState
} }
// AuthType returns the type of TLSInfo as a string.
func (t TLSInfo) AuthType() string { func (t TLSInfo) AuthType() string {
return "tls" return "tls"
} }
@ -185,20 +185,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
return conn, TLSInfo{conn.ConnectionState()}, nil return conn, TLSInfo{conn.ConnectionState()}, nil
} }
// NewTLS uses c to construct a TransportAuthenticator based on TLS. // NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportAuthenticator { func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{*c} tc := &tlsCreds{*c}
tc.config.NextProtos = alpnProtoStr tc.config.NextProtos = alpnProtoStr
return tc return tc
} }
// NewClientTLSFromCert constructs a TLS from the input certificate for client. // NewClientTLSFromCert constructs a TLS from the input certificate for client.
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator { func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials {
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
} }
// NewClientTLSFromFile constructs a TLS from the input certificate file for client. // NewClientTLSFromFile constructs a TLS from the input certificate file for client.
func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, error) { func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) {
b, err := ioutil.ReadFile(certFile) b, err := ioutil.ReadFile(certFile)
if err != nil { if err != nil {
return nil, err return nil, err
@ -211,13 +211,13 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator,
} }
// NewServerTLSFromCert constructs a TLS from the input certificate for server. // NewServerTLSFromCert constructs a TLS from the input certificate for server.
func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator { func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials {
return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}}) return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
} }
// NewServerTLSFromFile constructs a TLS from the input certificate file and key // NewServerTLSFromFile constructs a TLS from the input certificate file and key
// file for server. // file for server.
func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, error) { func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -66,7 +66,8 @@ type Resolver interface {
// Watcher watches for the updates on the specified target. // Watcher watches for the updates on the specified target.
type Watcher interface { type Watcher interface {
// Next blocks until an update or error happens. It may return one or more // Next blocks until an update or error happens. It may return one or more
// updates. The first call should get the full set of the results. // updates. The first call should get the full set of the results. It should
// return an error if and only if Watcher cannot recover.
Next() ([]*Update, error) Next() ([]*Update, error)
// Close closes the Watcher. // Close closes the Watcher.
Close() Close()

View File

@ -1,243 +0,0 @@
/*
*
* Copyright 2014, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"container/list"
"fmt"
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/transport"
)
// Picker picks a Conn for RPC requests.
// This is EXPERIMENTAL and please do not implement your own Picker for now.
type Picker interface {
// Init does initial processing for the Picker, e.g., initiate some connections.
Init(cc *ClientConn) error
// Pick blocks until either a transport.ClientTransport is ready for the upcoming RPC
// or some error happens.
Pick(ctx context.Context) (transport.ClientTransport, error)
// PickAddr picks a peer address for connecting. This will be called repeated for
// connecting/reconnecting.
PickAddr() (string, error)
// State returns the connectivity state of the underlying connections.
State() (ConnectivityState, error)
// WaitForStateChange blocks until the state changes to something other than
// the sourceState. It returns the new state or error.
WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error)
// Close closes all the Conn's owned by this Picker.
Close() error
}
// unicastPicker is the default Picker which is used when there is no custom Picker
// specified by users. It always picks the same Conn.
type unicastPicker struct {
target string
conn *Conn
}
func (p *unicastPicker) Init(cc *ClientConn) error {
c, err := NewConn(cc)
if err != nil {
return err
}
p.conn = c
return nil
}
func (p *unicastPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
return p.conn.Wait(ctx)
}
func (p *unicastPicker) PickAddr() (string, error) {
return p.target, nil
}
func (p *unicastPicker) State() (ConnectivityState, error) {
return p.conn.State(), nil
}
func (p *unicastPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
return p.conn.WaitForStateChange(ctx, sourceState)
}
func (p *unicastPicker) Close() error {
if p.conn != nil {
return p.conn.Close()
}
return nil
}
// unicastNamingPicker picks an address from a name resolver to set up the connection.
type unicastNamingPicker struct {
cc *ClientConn
resolver naming.Resolver
watcher naming.Watcher
mu sync.Mutex
// The list of the addresses are obtained from watcher.
addrs *list.List
// It tracks the current picked addr by PickAddr(). The next PickAddr may
// push it forward on addrs.
pickedAddr *list.Element
conn *Conn
}
// NewUnicastNamingPicker creates a Picker to pick addresses from a name resolver
// to connect.
func NewUnicastNamingPicker(r naming.Resolver) Picker {
return &unicastNamingPicker{
resolver: r,
addrs: list.New(),
}
}
type addrInfo struct {
addr string
// Set to true if this addrInfo needs to be deleted in the next PickAddrr() call.
deleting bool
}
// processUpdates calls Watcher.Next() once and processes the obtained updates.
func (p *unicastNamingPicker) processUpdates() error {
updates, err := p.watcher.Next()
if err != nil {
return err
}
for _, update := range updates {
switch update.Op {
case naming.Add:
p.mu.Lock()
p.addrs.PushBack(&addrInfo{
addr: update.Addr,
})
p.mu.Unlock()
// Initial connection setup
if p.conn == nil {
conn, err := NewConn(p.cc)
if err != nil {
return err
}
p.conn = conn
}
case naming.Delete:
p.mu.Lock()
for e := p.addrs.Front(); e != nil; e = e.Next() {
if update.Addr == e.Value.(*addrInfo).addr {
if e == p.pickedAddr {
// Do not remove the element now if it is the current picked
// one. We leave the deletion to the next PickAddr() call.
e.Value.(*addrInfo).deleting = true
// Notify Conn to close it. All the live RPCs on this connection
// will be aborted.
p.conn.NotifyReset()
} else {
p.addrs.Remove(e)
}
}
}
p.mu.Unlock()
default:
grpclog.Println("Unknown update.Op ", update.Op)
}
}
return nil
}
// monitor runs in a standalone goroutine to keep watching name resolution updates until the watcher
// is closed.
func (p *unicastNamingPicker) monitor() {
for {
if err := p.processUpdates(); err != nil {
return
}
}
}
func (p *unicastNamingPicker) Init(cc *ClientConn) error {
w, err := p.resolver.Resolve(cc.target)
if err != nil {
return err
}
p.watcher = w
p.cc = cc
// Get the initial name resolution.
if err := p.processUpdates(); err != nil {
return err
}
go p.monitor()
return nil
}
func (p *unicastNamingPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
return p.conn.Wait(ctx)
}
func (p *unicastNamingPicker) PickAddr() (string, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.pickedAddr == nil {
p.pickedAddr = p.addrs.Front()
} else {
pa := p.pickedAddr
p.pickedAddr = pa.Next()
if pa.Value.(*addrInfo).deleting {
p.addrs.Remove(pa)
}
if p.pickedAddr == nil {
p.pickedAddr = p.addrs.Front()
}
}
if p.pickedAddr == nil {
return "", fmt.Errorf("there is no address available to pick")
}
return p.pickedAddr.Value.(*addrInfo).addr, nil
}
func (p *unicastNamingPicker) State() (ConnectivityState, error) {
return 0, fmt.Errorf("State() is not supported for unicastNamingPicker")
}
func (p *unicastNamingPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
return 0, fmt.Errorf("WaitForStateChange is not supported for unicastNamingPciker")
}
func (p *unicastNamingPicker) Close() error {
p.watcher.Close()
p.conn.Close()
return nil
}

View File

@ -61,7 +61,7 @@ type Codec interface {
String() string String() string
} }
// protoCodec is a Codec implemetation with protobuf. It is the default codec for gRPC. // protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC.
type protoCodec struct{} type protoCodec struct{}
func (protoCodec) Marshal(v interface{}) ([]byte, error) { func (protoCodec) Marshal(v interface{}) ([]byte, error) {
@ -187,7 +187,7 @@ const (
compressionMade compressionMade
) )
// parser reads complelete gRPC messages from the underlying reader. // parser reads complete gRPC messages from the underlying reader.
type parser struct { type parser struct {
// r is the underlying reader. // r is the underlying reader.
// See the comment on recvMsg for the permissible // See the comment on recvMsg for the permissible
@ -284,14 +284,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
switch pf { switch pf {
case compressionNone: case compressionNone:
case compressionMade: case compressionMade:
if recvCompress == "" {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress)
}
if dc == nil || recvCompress != dc.Type() { if dc == nil || recvCompress != dc.Type() {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) return transport.StreamErrorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
} }
default: default:
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf) return transport.StreamErrorf(codes.Internal, "grpc: received unexpected payload format %d", pf)
} }
return nil return nil
} }

View File

@ -73,6 +73,7 @@ type ServiceDesc struct {
HandlerType interface{} HandlerType interface{}
Methods []MethodDesc Methods []MethodDesc
Streams []StreamDesc Streams []StreamDesc
Metadata interface{}
} }
// service consists of the information of the server serving this service and // service consists of the information of the server serving this service and
@ -95,10 +96,12 @@ type Server struct {
} }
type options struct { type options struct {
creds credentials.Credentials creds credentials.TransportCredentials
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
maxConcurrentStreams uint32 maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server useHandlerImpl bool // use http.Handler-based server
} }
@ -113,12 +116,14 @@ func CustomCodec(codec Codec) ServerOption {
} }
} }
// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
func RPCCompressor(cp Compressor) ServerOption { func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) { return func(o *options) {
o.cp = cp o.cp = cp
} }
} }
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
func RPCDecompressor(dc Decompressor) ServerOption { func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) { return func(o *options) {
o.dc = dc o.dc = dc
@ -134,12 +139,35 @@ func MaxConcurrentStreams(n uint32) ServerOption {
} }
// Creds returns a ServerOption that sets credentials for server connections. // Creds returns a ServerOption that sets credentials for server connections.
func Creds(c credentials.Credentials) ServerOption { func Creds(c credentials.TransportCredentials) ServerOption {
return func(o *options) { return func(o *options) {
o.creds = c o.creds = c
} }
} }
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
// server. Only one unary interceptor can be installed. The construction of multiple
// interceptors (e.g., chaining) can be implemented at the caller.
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
return func(o *options) {
if o.unaryInt != nil {
panic("The unary server interceptor has been set.")
}
o.unaryInt = i
}
}
// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
// server. Only one stream interceptor can be installed.
func StreamInterceptor(i StreamServerInterceptor) ServerOption {
return func(o *options) {
if o.streamInt != nil {
panic("The stream server interceptor has been set.")
}
o.streamInt = i
}
}
// NewServer creates a gRPC server which has no service registered and has not // NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
@ -222,22 +250,23 @@ var (
) )
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
creds, ok := s.opts.creds.(credentials.TransportAuthenticator) if s.opts.creds == nil {
if !ok {
return rawConn, nil, nil return rawConn, nil, nil
} }
return creds.ServerHandshake(rawConn) return s.opts.creds.ServerHandshake(rawConn)
} }
// Serve accepts incoming connections on the listener lis, creating a new // Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines // ServerTransport and service goroutine for each. The service goroutines
// read gRPC requests and then call the registered handlers to reply to them. // read gRPC requests and then call the registered handlers to reply to them.
// Service returns when lis.Accept fails. // Service returns when lis.Accept fails. lis will be closed when
// this method returns.
func (s *Server) Serve(lis net.Listener) error { func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock() s.mu.Lock()
s.printf("serving") s.printf("serving")
if s.lis == nil { if s.lis == nil {
s.mu.Unlock() s.mu.Unlock()
lis.Close()
return ErrServerStopped return ErrServerStopped
} }
s.lis[lis] = true s.lis[lis] = true
@ -435,6 +464,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
}() }()
} }
if s.opts.cp != nil {
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
stream.SetSendCompress(s.opts.cp.Type())
}
p := &parser{r: stream} p := &parser{r: stream}
for { for {
pf, req, err := p.recvMsg() pf, req, err := p.recvMsg()
@ -494,7 +527,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
return nil return nil
} }
reply, appErr := md.Handler(srv.server, stream.Context(), df, nil) reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
if appErr != nil { if appErr != nil {
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(rpcError); ok {
statusCode = err.code statusCode = err.code
@ -520,9 +553,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
Last: true, Last: true,
Delay: false, Delay: false,
} }
if s.opts.cp != nil {
stream.SetSendCompress(s.opts.cp.Type())
}
if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
switch err := err.(type) { switch err := err.(type) {
case transport.ConnectionError: case transport.ConnectionError:
@ -572,7 +602,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
ss.mu.Unlock() ss.mu.Unlock()
}() }()
} }
if appErr := sd.Handler(srv.server, ss); appErr != nil { var appErr error
if s.opts.streamInt == nil {
appErr = sd.Handler(srv.server, ss)
} else {
info := &StreamServerInfo{
FullMethod: stream.Method(),
IsClientStream: sd.ClientStreams,
IsServerStream: sd.ServerStreams,
}
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
}
if appErr != nil {
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(rpcError); ok {
ss.statusCode = err.code ss.statusCode = err.code
ss.statusDesc = err.desc ss.statusDesc = err.desc

View File

@ -79,9 +79,9 @@ type Stream interface {
RecvMsg(m interface{}) error RecvMsg(m interface{}) error
} }
// ClientStream defines the interface a client stream has to satify. // ClientStream defines the interface a client stream has to satisfy.
type ClientStream interface { type ClientStream interface {
// Header returns the header metedata received from the server if there // Header returns the header metadata received from the server if there
// is any. It blocks if the metadata is not ready to read. // is any. It blocks if the metadata is not ready to read.
Header() (metadata.MD, error) Header() (metadata.MD, error)
// Trailer returns the trailer metadata from the server. It must be called // Trailer returns the trailer metadata from the server. It must be called
@ -103,12 +103,16 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
var ( var (
t transport.ClientTransport t transport.ClientTransport
err error err error
put func()
) )
t, err = cc.dopts.picker.Pick(ctx) // TODO(zhaoq): CallOption is omitted. Add support when it is needed.
gopts := BalancerGetOptions{
BlockingWait: false,
}
t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
return nil, toRPCErr(err) return nil, toRPCErr(err)
} }
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
@ -119,6 +123,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
cs := &clientStream{ cs := &clientStream{
desc: desc, desc: desc,
put: put,
codec: cc.dopts.codec, codec: cc.dopts.codec,
cp: cc.dopts.cp, cp: cc.dopts.cp,
dc: cc.dopts.dc, dc: cc.dopts.dc,
@ -174,6 +179,7 @@ type clientStream struct {
tracing bool // set to EnableTracing when the clientStream is created. tracing bool // set to EnableTracing when the clientStream is created.
mu sync.Mutex mu sync.Mutex
put func()
closed bool closed bool
// trInfo.tr is set when the clientStream is created (if EnableTracing is true), // trInfo.tr is set when the clientStream is created (if EnableTracing is true),
// and is set to nil when the clientStream's finish method is called. // and is set to nil when the clientStream's finish method is called.
@ -311,6 +317,10 @@ func (cs *clientStream) finish(err error) {
} }
cs.mu.Lock() cs.mu.Lock()
defer cs.mu.Unlock() defer cs.mu.Unlock()
if cs.put != nil {
cs.put()
cs.put = nil
}
if cs.trInfo.tr != nil { if cs.trInfo.tr != nil {
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
cs.trInfo.tr.LazyPrintf("RPC: [OK]") cs.trInfo.tr.LazyPrintf("RPC: [OK]")

View File

@ -101,9 +101,8 @@ type payload struct {
func (p payload) String() string { func (p payload) String() string {
if p.sent { if p.sent {
return fmt.Sprintf("sent: %v", p.msg) return fmt.Sprintf("sent: %v", p.msg)
} else {
return fmt.Sprintf("recv: %v", p.msg)
} }
return fmt.Sprintf("recv: %v", p.msg)
} }
type fmtStringer struct { type fmtStringer struct {

View File

@ -65,7 +65,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
if r.Method != "POST" { if r.Method != "POST" {
return nil, errors.New("invalid gRPC request method") return nil, errors.New("invalid gRPC request method")
} }
if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { if !validContentType(r.Header.Get("Content-Type")) {
return nil, errors.New("invalid gRPC request content-type") return nil, errors.New("invalid gRPC request content-type")
} }
if _, ok := w.(http.Flusher); !ok { if _, ok := w.(http.Flusher); !ok {
@ -92,9 +92,12 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
} }
var metakv []string var metakv []string
if r.Host != "" {
metakv = append(metakv, ":authority", r.Host)
}
for k, vv := range r.Header { for k, vv := range r.Header {
k = strings.ToLower(k) k = strings.ToLower(k)
if isReservedHeader(k) { if isReservedHeader(k) && !isWhitelistedPseudoHeader(k) {
continue continue
} }
for _, v := range vv { for _, v := range vv {
@ -108,7 +111,6 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
} }
} }
metakv = append(metakv, k, v) metakv = append(metakv, k, v)
} }
} }
st.headerMD = metadata.Pairs(metakv...) st.headerMD = metadata.Pairs(metakv...)
@ -196,6 +198,10 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
} }
if md := s.Trailer(); len(md) > 0 { if md := s.Trailer(); len(md) > 0 {
for k, vv := range md { for k, vv := range md {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
if isReservedHeader(k) {
continue
}
for _, v := range vv { for _, v := range vv {
// http2 ResponseWriter mechanism to // http2 ResponseWriter mechanism to
// send undeclared Trailers after the // send undeclared Trailers after the
@ -249,6 +255,10 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
ht.writeCommonHeaders(s) ht.writeCommonHeaders(s)
h := ht.rw.Header() h := ht.rw.Header()
for k, vv := range md { for k, vv := range md {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
if isReservedHeader(k) {
continue
}
for _, v := range vv { for _, v := range vv {
h.Add(k, v) h.Add(k, v)
} }

View File

@ -35,7 +35,6 @@ package transport
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"math" "math"
"net" "net"
@ -89,7 +88,7 @@ type http2Client struct {
// The scheme used: https if TLS is on, http otherwise. // The scheme used: https if TLS is on, http otherwise.
scheme string scheme string
authCreds []credentials.Credentials creds []credentials.PerRPCCredentials
mu sync.Mutex // guard the following variables mu sync.Mutex // guard the following variables
state transportState // the state of underlying connection state transportState // the state of underlying connection
@ -118,19 +117,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
} }
var authInfo credentials.AuthInfo var authInfo credentials.AuthInfo
for _, c := range opts.AuthOptions { if opts.TransportCredentials != nil {
if ccreds, ok := c.(credentials.TransportAuthenticator); ok { scheme = "https"
scheme = "https" if timeout > 0 {
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are timeout -= time.Since(startT)
// multiple ones provided. Revisit this if it is not appropriate. Probably
// place the ClientTransport construction into a separate function to make
// things clear.
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
break
} }
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
} }
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
@ -140,29 +132,6 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
conn.Close() conn.Close()
} }
}() }()
// Send connection preface to server.
n, err := conn.Write(clientPreface)
if err != nil {
return nil, ConnectionErrorf("transport: %v", err)
}
if n != len(clientPreface) {
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
}
framer := newFramer(conn)
if initialWindowSize != defaultWindowSize {
err = framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
} else {
err = framer.writeSettings(true)
}
if err != nil {
return nil, ConnectionErrorf("transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
return nil, ConnectionErrorf("transport: %v", err)
}
}
ua := primaryUA ua := primaryUA
if opts.UserAgent != "" { if opts.UserAgent != "" {
ua = opts.UserAgent + " " + ua ua = opts.UserAgent + " " + ua
@ -178,7 +147,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
writableChan: make(chan int, 1), writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}), shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
framer: framer, framer: newFramer(conn),
hBuf: &buf, hBuf: &buf,
hEnc: hpack.NewEncoder(&buf), hEnc: hpack.NewEncoder(&buf),
controlBuf: newRecvBuffer(), controlBuf: newRecvBuffer(),
@ -187,17 +156,42 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
scheme: scheme, scheme: scheme,
state: reachable, state: reachable,
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
authCreds: opts.AuthOptions, creds: opts.PerRPCCredentials,
maxStreams: math.MaxInt32, maxStreams: math.MaxInt32,
streamSendQuota: defaultWindowSize, streamSendQuota: defaultWindowSize,
} }
// Start the reader goroutine for incoming message. Each transport has
// a dedicated goroutine which reads HTTP2 frame from network. Then it
// dispatches the frame to the corresponding stream entity.
go t.reader()
// Send connection preface to server.
n, err := t.conn.Write(clientPreface)
if err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
}
if n != len(clientPreface) {
t.Close()
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
}
if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
} else {
err = t.framer.writeSettings(true)
}
if err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close()
return nil, ConnectionErrorf("transport: %v", err)
}
}
go t.controller() go t.controller()
t.writableChan <- 0 t.writableChan <- 0
// Start the reader goroutine for incoming message. The threading model
// on receiving is that each transport has a dedicated goroutine which
// reads HTTP2 frame from network. Then it dispatches the frame to the
// corresponding stream entity.
go t.reader()
return t, nil return t, nil
} }
@ -247,7 +241,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
ctx = peer.NewContext(ctx, pr) ctx = peer.NewContext(ctx, pr)
authData := make(map[string]string) authData := make(map[string]string)
for _, c := range t.authCreds { for _, c := range t.creds {
// Construct URI required to get auth request metadata. // Construct URI required to get auth request metadata.
var port string var port string
if pos := strings.LastIndex(t.target, ":"); pos != -1 { if pos := strings.LastIndex(t.target, ":"); pos != -1 {
@ -270,6 +264,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
} }
t.mu.Lock() t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
return nil, ErrConnClosing
}
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
@ -287,7 +285,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
} }
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
// t.streamsQuota will be updated when t.CloseStream is invoked. // Return the quota back now because there is no stream returned to the caller.
if _, ok := err.(StreamError); ok && checkStreamsQuota {
t.streamsQuota.add(1)
}
return nil, err return nil, err
} }
t.mu.Lock() t.mu.Lock()
@ -339,6 +340,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if md, ok := metadata.FromContext(ctx); ok { if md, ok := metadata.FromContext(ctx); ok {
hasMD = true hasMD = true
for k, v := range md { for k, v := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
continue
}
for _, entry := range v { for _, entry := range v {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
} }
@ -388,9 +393,19 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
func (t *http2Client) CloseStream(s *Stream, err error) { func (t *http2Client) CloseStream(s *Stream, err error) {
var updateStreams bool var updateStreams bool
t.mu.Lock() t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
return
}
if t.streamsQuota != nil { if t.streamsQuota != nil {
updateStreams = true updateStreams = true
} }
if t.state == draining && len(t.activeStreams) == 1 {
// The transport is draining and s is the last live stream on t.
t.mu.Unlock()
t.Close()
return
}
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
t.mu.Unlock() t.mu.Unlock()
if updateStreams { if updateStreams {
@ -427,9 +442,12 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more. // accessed any more.
func (t *http2Client) Close() (err error) { func (t *http2Client) Close() (err error) {
t.mu.Lock() t.mu.Lock()
if t.state == reachable {
close(t.errorChan)
}
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return errors.New("transport: Close() was already called") return
} }
t.state = closing t.state = closing
t.mu.Unlock() t.mu.Unlock()
@ -452,6 +470,25 @@ func (t *http2Client) Close() (err error) {
return return
} }
func (t *http2Client) GracefulClose() error {
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
return nil
}
if t.state == draining {
t.mu.Unlock()
return nil
}
t.state = draining
active := len(t.activeStreams)
t.mu.Unlock()
if active == 0 {
return t.Close()
}
return nil
}
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller // Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil. // should proceed only if Write returns nil.
// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later // TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later
@ -574,6 +611,11 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
// Window updates will deliver to the controller for sending when // Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold. // the cumulative quota exceeds the corresponding threshold.
func (t *http2Client) updateWindow(s *Stream, n uint32) { func (t *http2Client) updateWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := t.fc.onRead(n); w > 0 { if w := t.fc.onRead(n); w > 0 {
t.controlBuf.put(&windowUpdate{0, w}) t.controlBuf.put(&windowUpdate{0, w})
} }

View File

@ -303,6 +303,11 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
// Window updates will deliver to the controller for sending when // Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold. // the cumulative quota exceeds the corresponding threshold.
func (t *http2Server) updateWindow(s *Stream, n uint32) { func (t *http2Server) updateWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := t.fc.onRead(n); w > 0 { if w := t.fc.onRead(n); w > 0 {
t.controlBuf.put(&windowUpdate{0, w}) t.controlBuf.put(&windowUpdate{0, w})
} }
@ -455,6 +460,10 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
} }
for k, v := range md { for k, v := range md {
if isReservedHeader(k) {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
continue
}
for _, entry := range v { for _, entry := range v {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
} }
@ -497,6 +506,10 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc})
// Attach the trailer metadata. // Attach the trailer metadata.
for k, v := range s.trailer { for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
if isReservedHeader(k) {
continue
}
for _, entry := range v { for _, entry := range v {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
} }

View File

@ -127,16 +127,40 @@ func isReservedHeader(hdr string) bool {
} }
} }
// isWhitelistedPseudoHeader checks whether hdr belongs to HTTP2 pseudoheaders
// that should be propagated into metadata visible to users.
func isWhitelistedPseudoHeader(hdr string) bool {
switch hdr {
case ":authority":
return true
default:
return false
}
}
func (d *decodeState) setErr(err error) { func (d *decodeState) setErr(err error) {
if d.err == nil { if d.err == nil {
d.err = err d.err = err
} }
} }
func validContentType(t string) bool {
e := "application/grpc"
if !strings.HasPrefix(t, e) {
return false
}
// Support variations on the content-type
// (e.g. "application/grpc+blah", "application/grpc;blah").
if len(t) > len(e) && t[len(e)] != '+' && t[len(e)] != ';' {
return false
}
return true
}
func (d *decodeState) processHeaderField(f hpack.HeaderField) { func (d *decodeState) processHeaderField(f hpack.HeaderField) {
switch f.Name { switch f.Name {
case "content-type": case "content-type":
if !strings.Contains(f.Value, "application/grpc") { if !validContentType(f.Value) {
d.setErr(StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value)) d.setErr(StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value))
return return
} }
@ -162,7 +186,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
case ":path": case ":path":
d.method = f.Value d.method = f.Value
default: default:
if !isReservedHeader(f.Name) { if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) {
if f.Name == "user-agent" { if f.Name == "user-agent" {
i := strings.LastIndex(f.Value, " ") i := strings.LastIndex(f.Value, " ")
if i == -1 { if i == -1 {

View File

@ -321,6 +321,7 @@ const (
reachable transportState = iota reachable transportState = iota
unreachable unreachable
closing closing
draining
) )
// NewServerTransport creates a ServerTransport with conn or non-nil error // NewServerTransport creates a ServerTransport with conn or non-nil error
@ -335,9 +336,11 @@ type ConnectOptions struct {
UserAgent string UserAgent string
// Dialer specifies how to dial a network address. // Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error) Dialer func(string, time.Duration) (net.Conn, error)
// AuthOptions stores the credentials required to setup a client connection and/or issue RPCs. // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
AuthOptions []credentials.Credentials PerRPCCredentials []credentials.PerRPCCredentials
// Timeout specifies the timeout for dialing a client connection. // TransportCredentials stores the Authenticator required to setup a client connection.
TransportCredentials credentials.TransportCredentials
// Timeout specifies the timeout for dialing a ClientTransport.
Timeout time.Duration Timeout time.Duration
} }
@ -391,6 +394,10 @@ type ClientTransport interface {
// is called only once. // is called only once.
Close() error Close() error
// GracefulClose starts to tear down the transport. It stops accepting
// new RPCs and wait the completion of the pending RPCs.
GracefulClose() error
// Write sends the data for the given stream. A nil stream indicates // Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole. // the write is to be performed on the transport as a whole.
Write(s *Stream, data []byte, opts *Options) error Write(s *Stream, data []byte, opts *Options) error
@ -468,7 +475,7 @@ func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc) return fmt.Sprintf("connection error: desc = %q", e.Desc)
} }
// Define some common ConnectionErrors. // ErrConnClosing indicates that the transport is closing.
var ErrConnClosing = ConnectionError{Desc: "transport is closing"} var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
// StreamError is an error that only affects one stream within a connection. // StreamError is an error that only affects one stream within a connection.