Allow gRPC clients to connect to multiple backends (#1918)
Fixes #1917 and #1755, also updates google.golang.org/grpc to b60d3e9e.
This commit is contained in:
parent
f04b922aff
commit
92e0704b1b
|
|
@ -200,39 +200,39 @@
|
|||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc/codes",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc/credentials",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc/grpclog",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc/internal",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc/metadata",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc/naming",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc/peer",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc/transport",
|
||||
"Rev": "dd828651e45229541896bc41cd9cf2f89ac7002a"
|
||||
"Rev": "88aeffff979aa77aa502cb011423d0a08fa12c5a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "gopkg.in/gorp.v1",
|
||||
|
|
|
|||
|
|
@ -13,14 +13,14 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
addr := flag.String("addr", "127.0.0.1:9090", "CCS address")
|
||||
addr := flag.String("addr", "boulder:9090", "CCS address")
|
||||
name := flag.String("name", "", "Name to check")
|
||||
issuer := flag.String("issuerDomain", "", "Issuer domain to check against")
|
||||
flag.Parse()
|
||||
|
||||
// Set up a connection to the server.
|
||||
conn, err := bgrpc.ClientSetup(&cmd.GRPCClientConfig{
|
||||
ServerAddress: *addr,
|
||||
ServerAddresses: []string{*addr},
|
||||
ServerIssuerPath: "test/grpc-creds/ca.pem",
|
||||
ClientCertificatePath: "test/grpc-creds/client.pem",
|
||||
ClientKeyPath: "test/grpc-creds/key.pem",
|
||||
|
|
|
|||
|
|
@ -511,7 +511,7 @@ type LogDescription struct {
|
|||
|
||||
// GRPCClientConfig contains the information needed to talk to the gRPC service
|
||||
type GRPCClientConfig struct {
|
||||
ServerAddress string
|
||||
ServerAddresses []string
|
||||
ServerIssuerPath string
|
||||
ClientCertificatePath string
|
||||
ClientKeyPath string
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
package grpc
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc/naming"
|
||||
)
|
||||
|
||||
// staticResolver implements both the naming.Resolver and naming.Watcher
|
||||
// interfaces. It always returns a single static list then blocks forever
|
||||
type staticResolver struct {
|
||||
addresses []*naming.Update
|
||||
}
|
||||
|
||||
func newStaticResolver(addresses []string) *staticResolver {
|
||||
sr := &staticResolver{}
|
||||
for _, a := range addresses {
|
||||
sr.addresses = append(sr.addresses, &naming.Update{
|
||||
Op: naming.Add,
|
||||
Addr: a,
|
||||
})
|
||||
}
|
||||
return sr
|
||||
}
|
||||
|
||||
// Resolve just returns the staticResolver it was called from as it satisfies
|
||||
// both the naming.Resolver and naming.Watcher interfaces
|
||||
func (sr *staticResolver) Resolve(target string) (naming.Watcher, error) {
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
// Next is called in a loop by grpc.RoundRobin expecting updates to which addresses are
|
||||
// appropriate. Since we just want to return a static list once return a list on the first
|
||||
// call then block forever on the second instead of sitting in a tight loop
|
||||
func (sr *staticResolver) Next() ([]*naming.Update, error) {
|
||||
if sr.addresses != nil {
|
||||
addrs := sr.addresses
|
||||
sr.addresses = nil
|
||||
return addrs, nil
|
||||
}
|
||||
// Since staticResolver.Next is called in a tight loop block forever
|
||||
// after returning the initial set of addresses
|
||||
forever := make(chan struct{})
|
||||
<-forever
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Close does nothing
|
||||
func (sr *staticResolver) Close() {}
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/naming"
|
||||
|
||||
"github.com/letsencrypt/boulder/test"
|
||||
)
|
||||
|
||||
func TestStaticResolver(t *testing.T) {
|
||||
names := []string{"test:443"}
|
||||
sr := newStaticResolver(names)
|
||||
watcher, err := sr.Resolve("")
|
||||
test.AssertNotError(t, err, "staticResolver.Resolve failed")
|
||||
|
||||
// Make sure doing this doesn't break anything (since it does nothing)
|
||||
watcher.Close()
|
||||
|
||||
updates, err := watcher.Next()
|
||||
test.AssertNotError(t, err, "staticwatcher.Next failed")
|
||||
test.AssertEquals(t, len(names), len(updates))
|
||||
test.AssertEquals(t, updates[0].Addr, "test:443")
|
||||
test.AssertEquals(t, updates[0].Op, naming.Add)
|
||||
test.AssertEquals(t, updates[0].Metadata, nil)
|
||||
|
||||
returned := make(chan struct{}, 1)
|
||||
go func() {
|
||||
_, err = watcher.Next()
|
||||
test.AssertNotError(t, err, "watcher.Next failed")
|
||||
returned <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-returned:
|
||||
t.Fatal("staticWatcher.Next returned something after the first call")
|
||||
case <-time.After(time.Millisecond * 500):
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
package creds
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
// transportCredentials is a grpc/credentials.TransportCredentials which supports
|
||||
// connecting to, and verifying multiple DNS names
|
||||
type transportCredentials struct {
|
||||
roots *x509.CertPool
|
||||
clients []tls.Certificate
|
||||
}
|
||||
|
||||
// New returns a new initialized grpc/credentials.TransportCredentials
|
||||
func New(rootCAs *x509.CertPool, clientCerts []tls.Certificate) credentials.TransportCredentials {
|
||||
return &transportCredentials{rootCAs, clientCerts}
|
||||
}
|
||||
|
||||
// ClientHandshake performs the TLS handshake for a client -> server connection
|
||||
func (tc *transportCredentials) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, credentials.AuthInfo, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn := tls.Client(rawConn, &tls.Config{
|
||||
ServerName: host,
|
||||
RootCAs: tc.roots,
|
||||
Certificates: tc.clients,
|
||||
MinVersion: tls.VersionTLS12, // Override default of tls.VersionTLS10
|
||||
MaxVersion: tls.VersionTLS12, // Same as default in golang <= 1.6
|
||||
})
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- conn.Handshake()
|
||||
}()
|
||||
select {
|
||||
case <-time.After(timeout):
|
||||
return nil, nil, errors.New("boulder/grpc/creds: TLS handshake timed out")
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
_ = rawConn.Close()
|
||||
return nil, nil, fmt.Errorf("boulder/grpc/creds: TLS handshake failed: %s", err)
|
||||
}
|
||||
return conn, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ServerHandshake performs the TLS handshake for a server <- client connection
|
||||
func (tc *transportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return nil, nil, fmt.Errorf("boulder/grpc/creds: Server-side handshakes are not implemented")
|
||||
}
|
||||
|
||||
// Info returns information about the transport protocol used
|
||||
func (tc *transportCredentials) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{
|
||||
SecurityProtocol: "tls",
|
||||
SecurityVersion: "1.2", // We *only* support TLS 1.2
|
||||
}
|
||||
}
|
||||
|
||||
// GetRequestMetadata returns nil, nil since TLS credentials do not have metadata.
|
||||
func (tc *transportCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// RequireTransportSecurity always returns true because TLS is transport security
|
||||
func (tc *transportCredentials) RequireTransportSecurity() bool {
|
||||
return true
|
||||
}
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
package creds
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/letsencrypt/boulder/test"
|
||||
)
|
||||
|
||||
func TestTransportCredentials(t *testing.T) {
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
test.AssertNotError(t, err, "rsa.GenerateKey failed")
|
||||
|
||||
temp := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "A",
|
||||
},
|
||||
NotBefore: time.Unix(1000, 0),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
derA, err := x509.CreateCertificate(rand.Reader, temp, temp, priv.Public(), priv)
|
||||
test.AssertNotError(t, err, "x509.CreateCertificate failed")
|
||||
certA, err := x509.ParseCertificate(derA)
|
||||
test.AssertNotError(t, err, "x509.ParserCertificate failed")
|
||||
temp.Subject.CommonName = "B"
|
||||
derB, err := x509.CreateCertificate(rand.Reader, temp, temp, priv.Public(), priv)
|
||||
test.AssertNotError(t, err, "x509.CreateCertificate failed")
|
||||
certB, err := x509.ParseCertificate(derB)
|
||||
test.AssertNotError(t, err, "x509.ParserCertificate failed")
|
||||
roots := x509.NewCertPool()
|
||||
roots.AddCert(certA)
|
||||
roots.AddCert(certB)
|
||||
|
||||
serverA := httptest.NewUnstartedServer(nil)
|
||||
serverA.TLS = &tls.Config{Certificates: []tls.Certificate{{Certificate: [][]byte{derA}, PrivateKey: priv}}}
|
||||
serverB := httptest.NewUnstartedServer(nil)
|
||||
serverB.TLS = &tls.Config{Certificates: []tls.Certificate{{Certificate: [][]byte{derB}, PrivateKey: priv}}}
|
||||
|
||||
tc := New(roots, nil)
|
||||
|
||||
serverA.StartTLS()
|
||||
defer serverA.Close()
|
||||
addrA := serverA.Listener.Addr().String()
|
||||
rawConnA, err := net.Dial("tcp", addrA)
|
||||
test.AssertNotError(t, err, "net.Dial failed")
|
||||
defer func() {
|
||||
_ = rawConnA.Close()
|
||||
}()
|
||||
|
||||
conn, _, err := tc.ClientHandshake("A:2020", rawConnA, time.Second)
|
||||
test.AssertNotError(t, err, "tc.ClientHandshake failed")
|
||||
test.Assert(t, conn != nil, "tc.ClientHandshake returned a nil net.Conn")
|
||||
|
||||
serverB.StartTLS()
|
||||
defer serverB.Close()
|
||||
addrB := serverB.Listener.Addr().String()
|
||||
rawConnB, err := net.Dial("tcp", addrB)
|
||||
test.AssertNotError(t, err, "net.Dial failed")
|
||||
defer func() {
|
||||
_ = rawConnB.Close()
|
||||
}()
|
||||
|
||||
conn, _, err = tc.ClientHandshake("B:3030", rawConnB, time.Second)
|
||||
test.AssertNotError(t, err, "tc.ClientHandshake failed")
|
||||
test.Assert(t, conn != nil, "tc.ClientHandshake returned a nil net.Conn")
|
||||
|
||||
// Test timeout
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
test.AssertNotError(t, err, "net.Listen failed")
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
addrC := ln.Addr().String()
|
||||
go func() {
|
||||
for {
|
||||
_, err := ln.Accept()
|
||||
test.AssertNotError(t, err, "ln.Accept failed")
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}()
|
||||
|
||||
rawConnC, err := net.Dial("tcp", addrC)
|
||||
test.AssertNotError(t, err, "net.Dial failed")
|
||||
defer func() {
|
||||
_ = rawConnB.Close()
|
||||
}()
|
||||
|
||||
conn, _, err = tc.ClientHandshake("A:2020", rawConnC, time.Millisecond)
|
||||
test.AssertError(t, err, "tc.ClientHandshake didn't timeout")
|
||||
test.AssertEquals(t, err.Error(), "boulder/grpc/creds: TLS handshake timed out")
|
||||
test.Assert(t, conn == nil, "tc.ClientHandshake returned a non-nil net.Conn on failure")
|
||||
}
|
||||
20
grpc/util.go
20
grpc/util.go
|
|
@ -12,17 +12,21 @@ import (
|
|||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/letsencrypt/boulder/cmd"
|
||||
bcreds "github.com/letsencrypt/boulder/grpc/creds"
|
||||
)
|
||||
|
||||
// CodedError is a alias required to appease go vet
|
||||
var CodedError = grpc.Errorf
|
||||
|
||||
// ClientSetup loads various TLS certificates and creates a
|
||||
// gRPC TransportAuthenticator that presents the client certificate
|
||||
// gRPC TransportCredentials that presents the client certificate
|
||||
// and validates the certificate presented by the server is for a
|
||||
// specific hostname and issued by the provided issuer certificate
|
||||
// thens dials and returns a grpc.ClientConn to the remote service.
|
||||
func ClientSetup(c *cmd.GRPCClientConfig) (*grpc.ClientConn, error) {
|
||||
if len(c.ServerAddresses) == 0 {
|
||||
return nil, fmt.Errorf("boulder/grpc: ServerAddresses is empty")
|
||||
}
|
||||
serverIssuerBytes, err := ioutil.ReadFile(c.ServerIssuerPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -35,15 +39,11 @@ func ClientSetup(c *cmd.GRPCClientConfig) (*grpc.ClientConn, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
host, _, err := net.SplitHostPort(c.ServerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return grpc.Dial(c.ServerAddress, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
ServerName: host,
|
||||
RootCAs: rootCAs,
|
||||
Certificates: []tls.Certificate{clientCert},
|
||||
})))
|
||||
return grpc.Dial(
|
||||
"", // Since our staticResolver provides addresses we don't need to pass an address here
|
||||
grpc.WithTransportCredentials(bcreds.New(rootCAs, []tls.Certificate{clientCert})),
|
||||
grpc.WithBalancer(grpc.RoundRobin(newStaticResolver(c.ServerAddresses))),
|
||||
)
|
||||
}
|
||||
|
||||
// NewServer loads various TLS certificates and creates a
|
||||
|
|
|
|||
|
|
@ -137,7 +137,7 @@
|
|||
},
|
||||
"maxConcurrentRPCServerRequests": 16,
|
||||
"publisherService": {
|
||||
"serverAddress": "boulder:9091",
|
||||
"serverAddresses": ["boulder:9091"],
|
||||
"serverIssuerPath": "test/grpc-creds/ca.pem",
|
||||
"clientCertificatePath": "test/grpc-creds/client.pem",
|
||||
"clientKeyPath": "test/grpc-creds/key.pem",
|
||||
|
|
@ -173,7 +173,7 @@
|
|||
"doNotForceCN": true,
|
||||
"reuseValidAuthz": true,
|
||||
"vaService": {
|
||||
"serverAddress": "boulder:9092",
|
||||
"serverAddresses": ["boulder:9092"],
|
||||
"serverIssuerPath": "test/grpc-creds/ca.pem",
|
||||
"clientCertificatePath": "test/grpc-creds/client.pem",
|
||||
"clientKeyPath": "test/grpc-creds/key.pem",
|
||||
|
|
@ -224,7 +224,7 @@
|
|||
"dnsTries": 3,
|
||||
"issuerDomain": "happy-hacker-ca.invalid",
|
||||
"caaService": {
|
||||
"serverAddress": "boulder:9090",
|
||||
"serverAddresses": ["boulder:9090"],
|
||||
"serverIssuerPath": "test/grpc-creds/ca.pem",
|
||||
"clientCertificatePath": "test/grpc-creds/client.pem",
|
||||
"clientKeyPath": "test/grpc-creds/key.pem"
|
||||
|
|
@ -297,7 +297,7 @@
|
|||
"signFailureBackoffMax": "30m",
|
||||
"debugAddr": "localhost:8006",
|
||||
"publisher": {
|
||||
"serverAddress": "boulder:9091",
|
||||
"serverAddresses": ["boulder:9091"],
|
||||
"serverIssuerPath": "test/grpc-creds/ca.pem",
|
||||
"clientCertificatePath": "test/grpc-creds/client.pem",
|
||||
"clientKeyPath": "test/grpc-creds/key.pem",
|
||||
|
|
|
|||
|
|
@ -21,8 +21,9 @@ proto:
|
|||
exit 1; \
|
||||
fi
|
||||
go get -u -v github.com/golang/protobuf/protoc-gen-go
|
||||
for file in $$(git ls-files '*.proto'); do \
|
||||
protoc -I $$(dirname $$file) --go_out=plugins=grpc:$$(dirname $$file) $$file; \
|
||||
# use $$dir as the root for all proto files in the same directory
|
||||
for dir in $$(git ls-files '*.proto' | xargs -n1 dirname | uniq); do \
|
||||
protoc -I $$dir --go_out=plugins=grpc:$$dir $$dir/*.proto; \
|
||||
done
|
||||
|
||||
test: testdeps
|
||||
|
|
|
|||
|
|
@ -1,22 +1,22 @@
|
|||
Additional IP Rights Grant (Patents)
|
||||
|
||||
"This implementation" means the copyrightable works distributed by
|
||||
Google as part of the GRPC project.
|
||||
Google as part of the gRPC project.
|
||||
|
||||
Google hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||
no-charge, royalty-free, irrevocable (except as stated in this section)
|
||||
patent license to make, have made, use, offer to sell, sell, import,
|
||||
transfer and otherwise run, modify and propagate the contents of this
|
||||
implementation of GRPC, where such license applies only to those patent
|
||||
implementation of gRPC, where such license applies only to those patent
|
||||
claims, both currently owned or controlled by Google and acquired in
|
||||
the future, licensable by Google that are necessarily infringed by this
|
||||
implementation of GRPC. This grant does not include claims that would be
|
||||
implementation of gRPC. This grant does not include claims that would be
|
||||
infringed only as a consequence of further modification of this
|
||||
implementation. If you or your agent or exclusive licensee institute or
|
||||
order or agree to the institution of patent litigation against any
|
||||
entity (including a cross-claim or counterclaim in a lawsuit) alleging
|
||||
that this implementation of GRPC or any code incorporated within this
|
||||
implementation of GRPC constitutes direct or contributory patent
|
||||
that this implementation of gRPC or any code incorporated within this
|
||||
implementation of gRPC constitutes direct or contributory patent
|
||||
infringement, or inducement of patent infringement, then any patent
|
||||
rights granted to you under this License for this implementation of GRPC
|
||||
rights granted to you under this License for this implementation of gRPC
|
||||
shall terminate as of the date such litigation is filed.
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ var (
|
|||
// backoffStrategy defines the methodology for backing off after a grpc
|
||||
// connection failure.
|
||||
//
|
||||
// This is unexported until the GRPC project decides whether or not to allow
|
||||
// This is unexported until the gRPC project decides whether or not to allow
|
||||
// alternative backoff strategies. Once a decision is made, this type and its
|
||||
// method may be exported.
|
||||
type backoffStrategy interface {
|
||||
|
|
@ -28,14 +28,14 @@ type backoffStrategy interface {
|
|||
backoff(retries int) time.Duration
|
||||
}
|
||||
|
||||
// BackoffConfig defines the parameters for the default GRPC backoff strategy.
|
||||
// BackoffConfig defines the parameters for the default gRPC backoff strategy.
|
||||
type BackoffConfig struct {
|
||||
// MaxDelay is the upper bound of backoff delay.
|
||||
MaxDelay time.Duration
|
||||
|
||||
// TODO(stevvooe): The following fields are not exported, as allowing
|
||||
// changes would violate the current GRPC specification for backoff. If
|
||||
// GRPC decides to allow more interesting backoff strategies, these fields
|
||||
// changes would violate the current gRPC specification for backoff. If
|
||||
// gRPC decides to allow more interesting backoff strategies, these fields
|
||||
// may be opened up in the future.
|
||||
|
||||
// baseDelay is the amount of time to wait before retrying after the first
|
||||
|
|
|
|||
|
|
@ -0,0 +1,340 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2016, Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
* * Neither the name of Google Inc. nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/naming"
|
||||
"google.golang.org/grpc/transport"
|
||||
)
|
||||
|
||||
// Address represents a server the client connects to.
|
||||
// This is the EXPERIMENTAL API and may be changed or extended in the future.
|
||||
type Address struct {
|
||||
// Addr is the server address on which a connection will be established.
|
||||
Addr string
|
||||
// Metadata is the information associated with Addr, which may be used
|
||||
// to make load balancing decision.
|
||||
Metadata interface{}
|
||||
}
|
||||
|
||||
// BalancerGetOptions configures a Get call.
|
||||
// This is the EXPERIMENTAL API and may be changed or extended in the future.
|
||||
type BalancerGetOptions struct {
|
||||
// BlockingWait specifies whether Get should block when there is no
|
||||
// connected address.
|
||||
BlockingWait bool
|
||||
}
|
||||
|
||||
// Balancer chooses network addresses for RPCs.
|
||||
// This is the EXPERIMENTAL API and may be changed or extended in the future.
|
||||
type Balancer interface {
|
||||
// Start does the initialization work to bootstrap a Balancer. For example,
|
||||
// this function may start the name resolution and watch the updates. It will
|
||||
// be called when dialing.
|
||||
Start(target string) error
|
||||
// Up informs the Balancer that gRPC has a connection to the server at
|
||||
// addr. It returns down which is called once the connection to addr gets
|
||||
// lost or closed.
|
||||
// TODO: It is not clear how to construct and take advantage the meaningful error
|
||||
// parameter for down. Need realistic demands to guide.
|
||||
Up(addr Address) (down func(error))
|
||||
// Get gets the address of a server for the RPC corresponding to ctx.
|
||||
// i) If it returns a connected address, gRPC internals issues the RPC on the
|
||||
// connection to this address;
|
||||
// ii) If it returns an address on which the connection is under construction
|
||||
// (initiated by Notify(...)) but not connected, gRPC internals
|
||||
// * fails RPC if the RPC is fail-fast and connection is in the TransientFailure or
|
||||
// Shutdown state;
|
||||
// or
|
||||
// * issues RPC on the connection otherwise.
|
||||
// iii) If it returns an address on which the connection does not exist, gRPC
|
||||
// internals treats it as an error and will fail the corresponding RPC.
|
||||
//
|
||||
// Therefore, the following is the recommended rule when writing a custom Balancer.
|
||||
// If opts.BlockingWait is true, it should return a connected address or
|
||||
// block if there is no connected address. It should respect the timeout or
|
||||
// cancellation of ctx when blocking. If opts.BlockingWait is false (for fail-fast
|
||||
// RPCs), it should return an address it has notified via Notify(...) immediately
|
||||
// instead of blocking.
|
||||
//
|
||||
// The function returns put which is called once the rpc has completed or failed.
|
||||
// put can collect and report RPC stats to a remote load balancer. gRPC internals
|
||||
// will try to call this again if err is non-nil (unless err is ErrClientConnClosing).
|
||||
//
|
||||
// TODO: Add other non-recoverable errors?
|
||||
Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
|
||||
// Notify returns a channel that is used by gRPC internals to watch the addresses
|
||||
// gRPC needs to connect. The addresses might be from a name resolver or remote
|
||||
// load balancer. gRPC internals will compare it with the existing connected
|
||||
// addresses. If the address Balancer notified is not in the existing connected
|
||||
// addresses, gRPC starts to connect the address. If an address in the existing
|
||||
// connected addresses is not in the notification list, the corresponding connection
|
||||
// is shutdown gracefully. Otherwise, there are no operations to take. Note that
|
||||
// the Address slice must be the full list of the Addresses which should be connected.
|
||||
// It is NOT delta.
|
||||
Notify() <-chan []Address
|
||||
// Close shuts down the balancer.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// downErr implements net.Error. It is constructed by gRPC internals and passed to the down
|
||||
// call of Balancer.
|
||||
type downErr struct {
|
||||
timeout bool
|
||||
temporary bool
|
||||
desc string
|
||||
}
|
||||
|
||||
func (e downErr) Error() string { return e.desc }
|
||||
func (e downErr) Timeout() bool { return e.timeout }
|
||||
func (e downErr) Temporary() bool { return e.temporary }
|
||||
|
||||
func downErrorf(timeout, temporary bool, format string, a ...interface{}) downErr {
|
||||
return downErr{
|
||||
timeout: timeout,
|
||||
temporary: temporary,
|
||||
desc: fmt.Sprintf(format, a...),
|
||||
}
|
||||
}
|
||||
|
||||
// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
|
||||
// the name resolution updates and updates the addresses available correspondingly.
|
||||
func RoundRobin(r naming.Resolver) Balancer {
|
||||
return &roundRobin{r: r}
|
||||
}
|
||||
|
||||
type roundRobin struct {
|
||||
r naming.Resolver
|
||||
w naming.Watcher
|
||||
open []Address // all the addresses the client should potentially connect
|
||||
mu sync.Mutex
|
||||
addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to.
|
||||
connected []Address // all the connected addresses
|
||||
next int // index of the next address to return for Get()
|
||||
waitCh chan struct{} // the channel to block when there is no connected address available
|
||||
done bool // The Balancer is closed.
|
||||
}
|
||||
|
||||
func (rr *roundRobin) watchAddrUpdates() error {
|
||||
updates, err := rr.w.Next()
|
||||
if err != nil {
|
||||
grpclog.Println("grpc: the naming watcher stops working due to %v.", err)
|
||||
return err
|
||||
}
|
||||
rr.mu.Lock()
|
||||
defer rr.mu.Unlock()
|
||||
for _, update := range updates {
|
||||
addr := Address{
|
||||
Addr: update.Addr,
|
||||
}
|
||||
switch update.Op {
|
||||
case naming.Add:
|
||||
var exist bool
|
||||
for _, v := range rr.open {
|
||||
if addr == v {
|
||||
exist = true
|
||||
grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
|
||||
break
|
||||
}
|
||||
}
|
||||
if exist {
|
||||
continue
|
||||
}
|
||||
rr.open = append(rr.open, addr)
|
||||
case naming.Delete:
|
||||
for i, v := range rr.open {
|
||||
if v == addr {
|
||||
copy(rr.open[i:], rr.open[i+1:])
|
||||
rr.open = rr.open[:len(rr.open)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
grpclog.Println("Unknown update.Op ", update.Op)
|
||||
}
|
||||
}
|
||||
// Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals gets notified.
|
||||
open := make([]Address, len(rr.open), len(rr.open))
|
||||
copy(open, rr.open)
|
||||
if rr.done {
|
||||
return ErrClientConnClosing
|
||||
}
|
||||
rr.addrCh <- open
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rr *roundRobin) Start(target string) error {
|
||||
if rr.r == nil {
|
||||
// If there is no name resolver installed, it is not needed to
|
||||
// do name resolution. In this case, rr.addrCh stays nil.
|
||||
return nil
|
||||
}
|
||||
w, err := rr.r.Resolve(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rr.w = w
|
||||
rr.addrCh = make(chan []Address)
|
||||
go func() {
|
||||
for {
|
||||
if err := rr.watchAddrUpdates(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Up appends addr to the end of rr.connected and sends notification if there
|
||||
// are pending Get() calls.
|
||||
func (rr *roundRobin) Up(addr Address) func(error) {
|
||||
rr.mu.Lock()
|
||||
defer rr.mu.Unlock()
|
||||
for _, a := range rr.connected {
|
||||
if a == addr {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
rr.connected = append(rr.connected, addr)
|
||||
if len(rr.connected) == 1 {
|
||||
// addr is only one available. Notify the Get() callers who are blocking.
|
||||
if rr.waitCh != nil {
|
||||
close(rr.waitCh)
|
||||
rr.waitCh = nil
|
||||
}
|
||||
}
|
||||
return func(err error) {
|
||||
rr.down(addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// down removes addr from rr.connected and moves the remaining addrs forward.
|
||||
func (rr *roundRobin) down(addr Address, err error) {
|
||||
rr.mu.Lock()
|
||||
defer rr.mu.Unlock()
|
||||
for i, a := range rr.connected {
|
||||
if a == addr {
|
||||
copy(rr.connected[i:], rr.connected[i+1:])
|
||||
rr.connected = rr.connected[:len(rr.connected)-1]
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the next addr in the rotation.
|
||||
func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
|
||||
var ch chan struct{}
|
||||
rr.mu.Lock()
|
||||
if rr.done {
|
||||
rr.mu.Unlock()
|
||||
err = ErrClientConnClosing
|
||||
return
|
||||
}
|
||||
if rr.next >= len(rr.connected) {
|
||||
rr.next = 0
|
||||
}
|
||||
if len(rr.connected) > 0 {
|
||||
addr = rr.connected[rr.next]
|
||||
rr.next++
|
||||
rr.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// There is no address available. Wait on rr.waitCh.
|
||||
// TODO(zhaoq): Handle the case when opts.BlockingWait is false.
|
||||
if rr.waitCh == nil {
|
||||
ch = make(chan struct{})
|
||||
rr.waitCh = ch
|
||||
} else {
|
||||
ch = rr.waitCh
|
||||
}
|
||||
rr.mu.Unlock()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = transport.ContextErr(ctx.Err())
|
||||
return
|
||||
case <-ch:
|
||||
rr.mu.Lock()
|
||||
if rr.done {
|
||||
rr.mu.Unlock()
|
||||
err = ErrClientConnClosing
|
||||
return
|
||||
}
|
||||
if len(rr.connected) == 0 {
|
||||
// The newly added addr got removed by Down() again.
|
||||
if rr.waitCh == nil {
|
||||
ch = make(chan struct{})
|
||||
rr.waitCh = ch
|
||||
} else {
|
||||
ch = rr.waitCh
|
||||
}
|
||||
rr.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
if rr.next >= len(rr.connected) {
|
||||
rr.next = 0
|
||||
}
|
||||
addr = rr.connected[rr.next]
|
||||
rr.next++
|
||||
rr.mu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rr *roundRobin) Notify() <-chan []Address {
|
||||
return rr.addrCh
|
||||
}
|
||||
|
||||
func (rr *roundRobin) Close() error {
|
||||
rr.mu.Lock()
|
||||
defer rr.mu.Unlock()
|
||||
rr.done = true
|
||||
if rr.w != nil {
|
||||
rr.w.Close()
|
||||
}
|
||||
if rr.waitCh != nil {
|
||||
close(rr.waitCh)
|
||||
rr.waitCh = nil
|
||||
}
|
||||
if rr.addrCh != nil {
|
||||
close(rr.addrCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -132,19 +132,16 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||
Last: true,
|
||||
Delay: false,
|
||||
}
|
||||
var (
|
||||
lastErr error // record the error that happened
|
||||
)
|
||||
for {
|
||||
var (
|
||||
err error
|
||||
t transport.ClientTransport
|
||||
stream *transport.Stream
|
||||
// Record the put handler from Balancer.Get(...). It is called once the
|
||||
// RPC has completed or failed.
|
||||
put func()
|
||||
)
|
||||
// TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs.
|
||||
if lastErr != nil && c.failFast {
|
||||
return toRPCErr(lastErr)
|
||||
}
|
||||
// TODO(zhaoq): Need a formal spec of fail-fast.
|
||||
callHdr := &transport.CallHdr{
|
||||
Host: cc.authority,
|
||||
Method: method,
|
||||
|
|
@ -152,39 +149,66 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||
if cc.dopts.cp != nil {
|
||||
callHdr.SendCompress = cc.dopts.cp.Type()
|
||||
}
|
||||
t, err = cc.dopts.picker.Pick(ctx)
|
||||
gopts := BalancerGetOptions{
|
||||
BlockingWait: !c.failFast,
|
||||
}
|
||||
t, put, err = cc.getTransport(ctx, gopts)
|
||||
if err != nil {
|
||||
if lastErr != nil {
|
||||
// This was a retry; return the error from the last attempt.
|
||||
return toRPCErr(lastErr)
|
||||
// TODO(zhaoq): Probably revisit the error handling.
|
||||
if err == ErrClientConnClosing {
|
||||
return Errorf(codes.FailedPrecondition, "%v", err)
|
||||
}
|
||||
return toRPCErr(err)
|
||||
if _, ok := err.(transport.StreamError); ok {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
if _, ok := err.(transport.ConnectionError); ok {
|
||||
if c.failFast {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
}
|
||||
// All the remaining cases are treated as retryable.
|
||||
continue
|
||||
}
|
||||
if c.traceInfo.tr != nil {
|
||||
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
|
||||
}
|
||||
stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
|
||||
if err != nil {
|
||||
if _, ok := err.(transport.ConnectionError); ok {
|
||||
lastErr = err
|
||||
continue
|
||||
if put != nil {
|
||||
put()
|
||||
put = nil
|
||||
}
|
||||
if lastErr != nil {
|
||||
return toRPCErr(lastErr)
|
||||
if _, ok := err.(transport.ConnectionError); ok {
|
||||
if c.failFast {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return toRPCErr(err)
|
||||
}
|
||||
// Receive the response
|
||||
lastErr = recvResponse(cc.dopts, t, &c, stream, reply)
|
||||
if _, ok := lastErr.(transport.ConnectionError); ok {
|
||||
continue
|
||||
err = recvResponse(cc.dopts, t, &c, stream, reply)
|
||||
if err != nil {
|
||||
if put != nil {
|
||||
put()
|
||||
put = nil
|
||||
}
|
||||
if _, ok := err.(transport.ConnectionError); ok {
|
||||
if c.failFast {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
t.CloseStream(stream, err)
|
||||
return toRPCErr(err)
|
||||
}
|
||||
if c.traceInfo.tr != nil {
|
||||
c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
|
||||
}
|
||||
t.CloseStream(stream, lastErr)
|
||||
if lastErr != nil {
|
||||
return toRPCErr(lastErr)
|
||||
t.CloseStream(stream, nil)
|
||||
if put != nil {
|
||||
put()
|
||||
put = nil
|
||||
}
|
||||
return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,28 +43,38 @@ import (
|
|||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/trace"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/transport"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUnspecTarget indicates that the target address is unspecified.
|
||||
ErrUnspecTarget = errors.New("grpc: target is unspecified")
|
||||
// ErrNoTransportSecurity indicates that there is no transport security
|
||||
// ErrClientConnClosing indicates that the operation is illegal because
|
||||
// the ClientConn is closing.
|
||||
ErrClientConnClosing = errors.New("grpc: the client connection is closing")
|
||||
// ErrClientConnTimeout indicates that the ClientConn cannot establish the
|
||||
// underlying connections within the specified timeout.
|
||||
ErrClientConnTimeout = errors.New("grpc: timed out when dialing")
|
||||
|
||||
// errNoTransportSecurity indicates that there is no transport security
|
||||
// being set for ClientConn. Users should either set one or explicitly
|
||||
// call WithInsecure DialOption to disable security.
|
||||
ErrNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
|
||||
// ErrCredentialsMisuse indicates that users want to transmit security information
|
||||
// (e.g., oauth2 token) which requires secure connection on an insecure
|
||||
errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
|
||||
// errTransportCredentialsMissing indicates that users want to transmit security
|
||||
// information (e.g., oauth2 token) which requires secure connection on an insecure
|
||||
// connection.
|
||||
ErrCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)")
|
||||
// ErrClientConnClosing indicates that the operation is illegal because
|
||||
// the session is closing.
|
||||
ErrClientConnClosing = errors.New("grpc: the client connection is closing")
|
||||
// ErrClientConnTimeout indicates that the connection could not be
|
||||
// established or re-established within the specified timeout.
|
||||
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
|
||||
errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)")
|
||||
// errCredentialsConflict indicates that grpc.WithTransportCredentials()
|
||||
// and grpc.WithInsecure() are both called for a connection.
|
||||
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
|
||||
// errNetworkIP indicates that the connection is down due to some network I/O error.
|
||||
errNetworkIO = errors.New("grpc: failed with network I/O error")
|
||||
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
|
||||
errConnDrain = errors.New("grpc: the connection is drained")
|
||||
// errConnClosing indicates that the connection is closing.
|
||||
errConnClosing = errors.New("grpc: the connection is closing")
|
||||
errNoAddr = errors.New("grpc: there is no address available to dial")
|
||||
// minimum time to give a connection to complete
|
||||
minConnectTimeout = 20 * time.Second
|
||||
)
|
||||
|
|
@ -76,9 +86,10 @@ type dialOptions struct {
|
|||
cp Compressor
|
||||
dc Decompressor
|
||||
bs backoffStrategy
|
||||
picker Picker
|
||||
balancer Balancer
|
||||
block bool
|
||||
insecure bool
|
||||
timeout time.Duration
|
||||
copts transport.ConnectOptions
|
||||
}
|
||||
|
||||
|
|
@ -108,10 +119,10 @@ func WithDecompressor(dc Decompressor) DialOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithPicker returns a DialOption which sets a picker for connection selection.
|
||||
func WithPicker(p Picker) DialOption {
|
||||
// WithBalancer returns a DialOption which sets a load balancer.
|
||||
func WithBalancer(b Balancer) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.picker = p
|
||||
o.balancer = b
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -136,7 +147,7 @@ func WithBackoffConfig(b BackoffConfig) DialOption {
|
|||
// withBackoff sets the backoff strategy used for retries after a
|
||||
// failed connection attempt.
|
||||
//
|
||||
// This can be exported if arbitrary backoff strategies are allowed by GRPC.
|
||||
// This can be exported if arbitrary backoff strategies are allowed by gRPC.
|
||||
func withBackoff(bs backoffStrategy) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.bs = bs
|
||||
|
|
@ -162,24 +173,25 @@ func WithInsecure() DialOption {
|
|||
|
||||
// WithTransportCredentials returns a DialOption which configures a
|
||||
// connection level security credentials (e.g., TLS/SSL).
|
||||
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
|
||||
func WithTransportCredentials(creds credentials.TransportCredentials) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
||||
o.copts.TransportCredentials = creds
|
||||
}
|
||||
}
|
||||
|
||||
// WithPerRPCCredentials returns a DialOption which sets
|
||||
// credentials which will place auth state on each outbound RPC.
|
||||
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
||||
func WithPerRPCCredentials(creds credentials.PerRPCCredentials) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
||||
o.copts.PerRPCCredentials = append(o.copts.PerRPCCredentials, creds)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout returns a DialOption that configures a timeout for dialing a client connection.
|
||||
// WithTimeout returns a DialOption that configures a timeout for dialing a ClientConn
|
||||
// initially. This is valid if and only if WithBlock() is present.
|
||||
func WithTimeout(d time.Duration) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.copts.Timeout = d
|
||||
o.timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -201,6 +213,7 @@ func WithUserAgent(s string) DialOption {
|
|||
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
||||
cc := &ClientConn{
|
||||
target: target,
|
||||
conns: make(map[Address]*addrConn),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(&cc.dopts)
|
||||
|
|
@ -214,13 +227,53 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
|||
cc.dopts.bs = DefaultBackoffConfig
|
||||
}
|
||||
|
||||
if cc.dopts.picker == nil {
|
||||
cc.dopts.picker = &unicastPicker{
|
||||
target: target,
|
||||
cc.balancer = cc.dopts.balancer
|
||||
if cc.balancer == nil {
|
||||
cc.balancer = RoundRobin(nil)
|
||||
}
|
||||
if err := cc.balancer.Start(target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var (
|
||||
ok bool
|
||||
addrs []Address
|
||||
)
|
||||
ch := cc.balancer.Notify()
|
||||
if ch == nil {
|
||||
// There is no name resolver installed.
|
||||
addrs = append(addrs, Address{Addr: target})
|
||||
} else {
|
||||
addrs, ok = <-ch
|
||||
if !ok || len(addrs) == 0 {
|
||||
return nil, errNoAddr
|
||||
}
|
||||
}
|
||||
if err := cc.dopts.picker.Init(cc); err != nil {
|
||||
return nil, err
|
||||
waitC := make(chan error, 1)
|
||||
go func() {
|
||||
for _, a := range addrs {
|
||||
if err := cc.newAddrConn(a, false); err != nil {
|
||||
waitC <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
close(waitC)
|
||||
}()
|
||||
var timeoutCh <-chan time.Time
|
||||
if cc.dopts.timeout > 0 {
|
||||
timeoutCh = time.After(cc.dopts.timeout)
|
||||
}
|
||||
select {
|
||||
case err := <-waitC:
|
||||
if err != nil {
|
||||
cc.Close()
|
||||
return nil, err
|
||||
}
|
||||
case <-timeoutCh:
|
||||
cc.Close()
|
||||
return nil, ErrClientConnTimeout
|
||||
}
|
||||
if ok {
|
||||
go cc.lbWatcher()
|
||||
}
|
||||
colonPos := strings.LastIndex(target, ":")
|
||||
if colonPos == -1 {
|
||||
|
|
@ -263,325 +316,358 @@ func (s ConnectivityState) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
// ClientConn represents a client connection to an RPC service.
|
||||
// ClientConn represents a client connection to an RPC server.
|
||||
type ClientConn struct {
|
||||
target string
|
||||
balancer Balancer
|
||||
authority string
|
||||
dopts dialOptions
|
||||
|
||||
mu sync.RWMutex
|
||||
conns map[Address]*addrConn
|
||||
}
|
||||
|
||||
// State returns the connectivity state of cc.
|
||||
// This is EXPERIMENTAL API.
|
||||
func (cc *ClientConn) State() (ConnectivityState, error) {
|
||||
return cc.dopts.picker.State()
|
||||
func (cc *ClientConn) lbWatcher() {
|
||||
for addrs := range cc.balancer.Notify() {
|
||||
var (
|
||||
add []Address // Addresses need to setup connections.
|
||||
del []*addrConn // Connections need to tear down.
|
||||
)
|
||||
cc.mu.Lock()
|
||||
for _, a := range addrs {
|
||||
if _, ok := cc.conns[a]; !ok {
|
||||
add = append(add, a)
|
||||
}
|
||||
}
|
||||
for k, c := range cc.conns {
|
||||
var keep bool
|
||||
for _, a := range addrs {
|
||||
if k == a {
|
||||
keep = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !keep {
|
||||
del = append(del, c)
|
||||
}
|
||||
}
|
||||
cc.mu.Unlock()
|
||||
for _, a := range add {
|
||||
cc.newAddrConn(a, true)
|
||||
}
|
||||
for _, c := range del {
|
||||
c.tearDown(errConnDrain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForStateChange blocks until the state changes to something other than the sourceState.
|
||||
// It returns the new state or error.
|
||||
// This is EXPERIMENTAL API.
|
||||
func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
|
||||
return cc.dopts.picker.WaitForStateChange(ctx, sourceState)
|
||||
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
|
||||
ac := &addrConn{
|
||||
cc: cc,
|
||||
addr: addr,
|
||||
dopts: cc.dopts,
|
||||
shutdownChan: make(chan struct{}),
|
||||
}
|
||||
if EnableTracing {
|
||||
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
||||
}
|
||||
if !ac.dopts.insecure {
|
||||
if ac.dopts.copts.TransportCredentials == nil {
|
||||
return errNoTransportSecurity
|
||||
}
|
||||
} else {
|
||||
if ac.dopts.copts.TransportCredentials != nil {
|
||||
return errCredentialsConflict
|
||||
}
|
||||
for _, cd := range ac.dopts.copts.PerRPCCredentials {
|
||||
if cd.RequireTransportSecurity() {
|
||||
return errTransportCredentialsMissing
|
||||
}
|
||||
}
|
||||
}
|
||||
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
|
||||
ac.cc.mu.Lock()
|
||||
if ac.cc.conns == nil {
|
||||
ac.cc.mu.Unlock()
|
||||
return ErrClientConnClosing
|
||||
}
|
||||
stale := ac.cc.conns[ac.addr]
|
||||
ac.cc.conns[ac.addr] = ac
|
||||
ac.cc.mu.Unlock()
|
||||
if stale != nil {
|
||||
// There is an addrConn alive on ac.addr already. This could be due to
|
||||
// i) stale's Close is undergoing;
|
||||
// ii) a buggy Balancer notifies duplicated Addresses.
|
||||
stale.tearDown(errConnDrain)
|
||||
}
|
||||
ac.stateCV = sync.NewCond(&ac.mu)
|
||||
// skipWait may overwrite the decision in ac.dopts.block.
|
||||
if ac.dopts.block && !skipWait {
|
||||
if err := ac.resetTransport(false); err != nil {
|
||||
ac.tearDown(err)
|
||||
return err
|
||||
}
|
||||
// Start to monitor the error status of transport.
|
||||
go ac.transportMonitor()
|
||||
} else {
|
||||
// Start a goroutine connecting to the server asynchronously.
|
||||
go func() {
|
||||
if err := ac.resetTransport(false); err != nil {
|
||||
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
|
||||
ac.tearDown(err)
|
||||
return
|
||||
}
|
||||
ac.transportMonitor()
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close starts to tear down the ClientConn.
|
||||
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
|
||||
// TODO(zhaoq): Implement fail-fast logic.
|
||||
addr, put, err := cc.balancer.Get(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cc.mu.RLock()
|
||||
if cc.conns == nil {
|
||||
cc.mu.RUnlock()
|
||||
return nil, nil, ErrClientConnClosing
|
||||
}
|
||||
ac, ok := cc.conns[addr]
|
||||
cc.mu.RUnlock()
|
||||
if !ok {
|
||||
if put != nil {
|
||||
put()
|
||||
}
|
||||
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
|
||||
}
|
||||
t, err := ac.wait(ctx)
|
||||
if err != nil {
|
||||
if put != nil {
|
||||
put()
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
return t, put, nil
|
||||
}
|
||||
|
||||
// Close tears down the ClientConn and all underlying connections.
|
||||
func (cc *ClientConn) Close() error {
|
||||
return cc.dopts.picker.Close()
|
||||
cc.mu.Lock()
|
||||
if cc.conns == nil {
|
||||
cc.mu.Unlock()
|
||||
return ErrClientConnClosing
|
||||
}
|
||||
conns := cc.conns
|
||||
cc.conns = nil
|
||||
cc.mu.Unlock()
|
||||
cc.balancer.Close()
|
||||
for _, ac := range conns {
|
||||
ac.tearDown(ErrClientConnClosing)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Conn is a client connection to a single destination.
|
||||
type Conn struct {
|
||||
target string
|
||||
// addrConn is a network connection to a given address.
|
||||
type addrConn struct {
|
||||
cc *ClientConn
|
||||
addr Address
|
||||
dopts dialOptions
|
||||
resetChan chan int
|
||||
shutdownChan chan struct{}
|
||||
events trace.EventLog
|
||||
|
||||
mu sync.Mutex
|
||||
state ConnectivityState
|
||||
stateCV *sync.Cond
|
||||
down func(error) // the handler called when a connection is down.
|
||||
// ready is closed and becomes nil when a new transport is up or failed
|
||||
// due to timeout.
|
||||
ready chan struct{}
|
||||
transport transport.ClientTransport
|
||||
}
|
||||
|
||||
// NewConn creates a Conn.
|
||||
func NewConn(cc *ClientConn) (*Conn, error) {
|
||||
if cc.target == "" {
|
||||
return nil, ErrUnspecTarget
|
||||
}
|
||||
c := &Conn{
|
||||
target: cc.target,
|
||||
dopts: cc.dopts,
|
||||
resetChan: make(chan int, 1),
|
||||
shutdownChan: make(chan struct{}),
|
||||
}
|
||||
if EnableTracing {
|
||||
c.events = trace.NewEventLog("grpc.ClientConn", c.target)
|
||||
}
|
||||
if !c.dopts.insecure {
|
||||
var ok bool
|
||||
for _, cd := range c.dopts.copts.AuthOptions {
|
||||
if _, ok = cd.(credentials.TransportAuthenticator); ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, ErrNoTransportSecurity
|
||||
}
|
||||
} else {
|
||||
for _, cd := range c.dopts.copts.AuthOptions {
|
||||
if cd.RequireTransportSecurity() {
|
||||
return nil, ErrCredentialsMisuse
|
||||
}
|
||||
}
|
||||
}
|
||||
c.stateCV = sync.NewCond(&c.mu)
|
||||
if c.dopts.block {
|
||||
if err := c.resetTransport(false); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Start to monitor the error status of transport.
|
||||
go c.transportMonitor()
|
||||
} else {
|
||||
// Start a goroutine connecting to the server asynchronously.
|
||||
go func() {
|
||||
if err := c.resetTransport(false); err != nil {
|
||||
grpclog.Printf("Failed to dial %s: %v; please retry.", c.target, err)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
c.transportMonitor()
|
||||
}()
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// printf records an event in cc's event log, unless cc has been closed.
|
||||
// REQUIRES cc.mu is held.
|
||||
func (cc *Conn) printf(format string, a ...interface{}) {
|
||||
if cc.events != nil {
|
||||
cc.events.Printf(format, a...)
|
||||
// printf records an event in ac's event log, unless ac has been closed.
|
||||
// REQUIRES ac.mu is held.
|
||||
func (ac *addrConn) printf(format string, a ...interface{}) {
|
||||
if ac.events != nil {
|
||||
ac.events.Printf(format, a...)
|
||||
}
|
||||
}
|
||||
|
||||
// errorf records an error in cc's event log, unless cc has been closed.
|
||||
// REQUIRES cc.mu is held.
|
||||
func (cc *Conn) errorf(format string, a ...interface{}) {
|
||||
if cc.events != nil {
|
||||
cc.events.Errorf(format, a...)
|
||||
// errorf records an error in ac's event log, unless ac has been closed.
|
||||
// REQUIRES ac.mu is held.
|
||||
func (ac *addrConn) errorf(format string, a ...interface{}) {
|
||||
if ac.events != nil {
|
||||
ac.events.Errorf(format, a...)
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the connectivity state of the Conn
|
||||
func (cc *Conn) State() ConnectivityState {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
return cc.state
|
||||
// getState returns the connectivity state of the Conn
|
||||
func (ac *addrConn) getState() ConnectivityState {
|
||||
ac.mu.Lock()
|
||||
defer ac.mu.Unlock()
|
||||
return ac.state
|
||||
}
|
||||
|
||||
// WaitForStateChange blocks until the state changes to something other than the sourceState.
|
||||
func (cc *Conn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
if sourceState != cc.state {
|
||||
return cc.state, nil
|
||||
// waitForStateChange blocks until the state changes to something other than the sourceState.
|
||||
func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
|
||||
ac.mu.Lock()
|
||||
defer ac.mu.Unlock()
|
||||
if sourceState != ac.state {
|
||||
return ac.state, nil
|
||||
}
|
||||
done := make(chan struct{})
|
||||
var err error
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cc.mu.Lock()
|
||||
ac.mu.Lock()
|
||||
err = ctx.Err()
|
||||
cc.stateCV.Broadcast()
|
||||
cc.mu.Unlock()
|
||||
ac.stateCV.Broadcast()
|
||||
ac.mu.Unlock()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
for sourceState == cc.state {
|
||||
cc.stateCV.Wait()
|
||||
for sourceState == ac.state {
|
||||
ac.stateCV.Wait()
|
||||
if err != nil {
|
||||
return cc.state, err
|
||||
return ac.state, err
|
||||
}
|
||||
}
|
||||
return cc.state, nil
|
||||
return ac.state, nil
|
||||
}
|
||||
|
||||
// NotifyReset tries to signal the underlying transport needs to be reset due to
|
||||
// for example a name resolution change in flight.
|
||||
func (cc *Conn) NotifyReset() {
|
||||
select {
|
||||
case cc.resetChan <- 0:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *Conn) resetTransport(closeTransport bool) error {
|
||||
func (ac *addrConn) resetTransport(closeTransport bool) error {
|
||||
var retries int
|
||||
start := time.Now()
|
||||
for {
|
||||
cc.mu.Lock()
|
||||
cc.printf("connecting")
|
||||
if cc.state == Shutdown {
|
||||
// cc.Close() has been invoked.
|
||||
cc.mu.Unlock()
|
||||
return ErrClientConnClosing
|
||||
ac.mu.Lock()
|
||||
ac.printf("connecting")
|
||||
if ac.state == Shutdown {
|
||||
// ac.tearDown(...) has been invoked.
|
||||
ac.mu.Unlock()
|
||||
return errConnClosing
|
||||
}
|
||||
cc.state = Connecting
|
||||
cc.stateCV.Broadcast()
|
||||
cc.mu.Unlock()
|
||||
if closeTransport {
|
||||
cc.transport.Close()
|
||||
if ac.down != nil {
|
||||
ac.down(downErrorf(false, true, "%v", errNetworkIO))
|
||||
ac.down = nil
|
||||
}
|
||||
// Adjust timeout for the current try.
|
||||
copts := cc.dopts.copts
|
||||
if copts.Timeout < 0 {
|
||||
cc.Close()
|
||||
return ErrClientConnTimeout
|
||||
ac.state = Connecting
|
||||
ac.stateCV.Broadcast()
|
||||
t := ac.transport
|
||||
ac.mu.Unlock()
|
||||
if closeTransport && t != nil {
|
||||
t.Close()
|
||||
}
|
||||
if copts.Timeout > 0 {
|
||||
copts.Timeout -= time.Since(start)
|
||||
if copts.Timeout <= 0 {
|
||||
cc.Close()
|
||||
return ErrClientConnTimeout
|
||||
}
|
||||
}
|
||||
sleepTime := cc.dopts.bs.backoff(retries)
|
||||
timeout := sleepTime
|
||||
if timeout < minConnectTimeout {
|
||||
timeout = minConnectTimeout
|
||||
}
|
||||
if copts.Timeout == 0 || copts.Timeout > timeout {
|
||||
copts.Timeout = timeout
|
||||
sleepTime := ac.dopts.bs.backoff(retries)
|
||||
ac.dopts.copts.Timeout = sleepTime
|
||||
if sleepTime < minConnectTimeout {
|
||||
ac.dopts.copts.Timeout = minConnectTimeout
|
||||
}
|
||||
connectTime := time.Now()
|
||||
addr, err := cc.dopts.picker.PickAddr()
|
||||
var newTransport transport.ClientTransport
|
||||
if err == nil {
|
||||
newTransport, err = transport.NewClientTransport(addr, &copts)
|
||||
}
|
||||
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts)
|
||||
if err != nil {
|
||||
cc.mu.Lock()
|
||||
if cc.state == Shutdown {
|
||||
// cc.Close() has been invoked.
|
||||
cc.mu.Unlock()
|
||||
return ErrClientConnClosing
|
||||
ac.mu.Lock()
|
||||
if ac.state == Shutdown {
|
||||
// ac.tearDown(...) has been invoked.
|
||||
ac.mu.Unlock()
|
||||
return errConnClosing
|
||||
}
|
||||
cc.errorf("transient failure: %v", err)
|
||||
cc.state = TransientFailure
|
||||
cc.stateCV.Broadcast()
|
||||
if cc.ready != nil {
|
||||
close(cc.ready)
|
||||
cc.ready = nil
|
||||
ac.errorf("transient failure: %v", err)
|
||||
ac.state = TransientFailure
|
||||
ac.stateCV.Broadcast()
|
||||
if ac.ready != nil {
|
||||
close(ac.ready)
|
||||
ac.ready = nil
|
||||
}
|
||||
cc.mu.Unlock()
|
||||
ac.mu.Unlock()
|
||||
sleepTime -= time.Since(connectTime)
|
||||
if sleepTime < 0 {
|
||||
sleepTime = 0
|
||||
}
|
||||
// Fail early before falling into sleep.
|
||||
if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) {
|
||||
cc.mu.Lock()
|
||||
cc.errorf("connection timeout")
|
||||
cc.mu.Unlock()
|
||||
cc.Close()
|
||||
return ErrClientConnTimeout
|
||||
}
|
||||
closeTransport = false
|
||||
time.Sleep(sleepTime)
|
||||
select {
|
||||
case <-time.After(sleepTime):
|
||||
case <-ac.shutdownChan:
|
||||
}
|
||||
retries++
|
||||
grpclog.Printf("grpc: Conn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target)
|
||||
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
|
||||
continue
|
||||
}
|
||||
cc.mu.Lock()
|
||||
cc.printf("ready")
|
||||
if cc.state == Shutdown {
|
||||
// cc.Close() has been invoked.
|
||||
cc.mu.Unlock()
|
||||
ac.mu.Lock()
|
||||
ac.printf("ready")
|
||||
if ac.state == Shutdown {
|
||||
// ac.tearDown(...) has been invoked.
|
||||
ac.mu.Unlock()
|
||||
newTransport.Close()
|
||||
return ErrClientConnClosing
|
||||
return errConnClosing
|
||||
}
|
||||
cc.state = Ready
|
||||
cc.stateCV.Broadcast()
|
||||
cc.transport = newTransport
|
||||
if cc.ready != nil {
|
||||
close(cc.ready)
|
||||
cc.ready = nil
|
||||
ac.state = Ready
|
||||
ac.stateCV.Broadcast()
|
||||
ac.transport = newTransport
|
||||
if ac.ready != nil {
|
||||
close(ac.ready)
|
||||
ac.ready = nil
|
||||
}
|
||||
cc.mu.Unlock()
|
||||
ac.down = ac.cc.balancer.Up(ac.addr)
|
||||
ac.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *Conn) reconnect() bool {
|
||||
cc.mu.Lock()
|
||||
if cc.state == Shutdown {
|
||||
// cc.Close() has been invoked.
|
||||
cc.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
cc.state = TransientFailure
|
||||
cc.stateCV.Broadcast()
|
||||
cc.mu.Unlock()
|
||||
if err := cc.resetTransport(true); err != nil {
|
||||
// The ClientConn is closing.
|
||||
cc.mu.Lock()
|
||||
cc.printf("transport exiting: %v", err)
|
||||
cc.mu.Unlock()
|
||||
grpclog.Printf("grpc: Conn.transportMonitor exits due to: %v", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Run in a goroutine to track the error in transport and create the
|
||||
// new transport if an error happens. It returns when the channel is closing.
|
||||
func (cc *Conn) transportMonitor() {
|
||||
func (ac *addrConn) transportMonitor() {
|
||||
for {
|
||||
ac.mu.Lock()
|
||||
t := ac.transport
|
||||
ac.mu.Unlock()
|
||||
select {
|
||||
// shutdownChan is needed to detect the teardown when
|
||||
// the ClientConn is idle (i.e., no RPC in flight).
|
||||
case <-cc.shutdownChan:
|
||||
// the addrConn is idle (i.e., no RPC in flight).
|
||||
case <-ac.shutdownChan:
|
||||
return
|
||||
case <-cc.resetChan:
|
||||
if !cc.reconnect() {
|
||||
case <-t.Error():
|
||||
ac.mu.Lock()
|
||||
if ac.state == Shutdown {
|
||||
// ac.tearDown(...) has been invoked.
|
||||
ac.mu.Unlock()
|
||||
return
|
||||
}
|
||||
case <-cc.transport.Error():
|
||||
if !cc.reconnect() {
|
||||
ac.state = TransientFailure
|
||||
ac.stateCV.Broadcast()
|
||||
ac.mu.Unlock()
|
||||
if err := ac.resetTransport(true); err != nil {
|
||||
ac.mu.Lock()
|
||||
ac.printf("transport exiting: %v", err)
|
||||
ac.mu.Unlock()
|
||||
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
|
||||
return
|
||||
}
|
||||
// Tries to drain reset signal if there is any since it is out-dated.
|
||||
select {
|
||||
case <-cc.resetChan:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait blocks until i) the new transport is up or ii) ctx is done or iii) cc is closed.
|
||||
func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) {
|
||||
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed.
|
||||
func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) {
|
||||
for {
|
||||
cc.mu.Lock()
|
||||
ac.mu.Lock()
|
||||
switch {
|
||||
case cc.state == Shutdown:
|
||||
cc.mu.Unlock()
|
||||
return nil, ErrClientConnClosing
|
||||
case cc.state == Ready:
|
||||
ct := cc.transport
|
||||
cc.mu.Unlock()
|
||||
case ac.state == Shutdown:
|
||||
ac.mu.Unlock()
|
||||
return nil, errConnClosing
|
||||
case ac.state == Ready:
|
||||
ct := ac.transport
|
||||
ac.mu.Unlock()
|
||||
return ct, nil
|
||||
default:
|
||||
ready := cc.ready
|
||||
ready := ac.ready
|
||||
if ready == nil {
|
||||
ready = make(chan struct{})
|
||||
cc.ready = ready
|
||||
ac.ready = ready
|
||||
}
|
||||
cc.mu.Unlock()
|
||||
ac.mu.Unlock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, transport.ContextErr(ctx.Err())
|
||||
|
|
@ -592,32 +678,46 @@ func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Close starts to tear down the Conn. Returns ErrClientConnClosing if
|
||||
// it has been closed (mostly due to dial time-out).
|
||||
// tearDown starts to tear down the addrConn.
|
||||
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
|
||||
// some edge cases (e.g., the caller opens and closes many ClientConn's in a
|
||||
// some edge cases (e.g., the caller opens and closes many addrConn's in a
|
||||
// tight loop.
|
||||
func (cc *Conn) Close() error {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
if cc.state == Shutdown {
|
||||
return ErrClientConnClosing
|
||||
func (ac *addrConn) tearDown(err error) {
|
||||
ac.mu.Lock()
|
||||
defer func() {
|
||||
ac.mu.Unlock()
|
||||
ac.cc.mu.Lock()
|
||||
if ac.cc.conns != nil {
|
||||
delete(ac.cc.conns, ac.addr)
|
||||
}
|
||||
ac.cc.mu.Unlock()
|
||||
}()
|
||||
if ac.state == Shutdown {
|
||||
return
|
||||
}
|
||||
cc.state = Shutdown
|
||||
cc.stateCV.Broadcast()
|
||||
if cc.events != nil {
|
||||
cc.events.Finish()
|
||||
cc.events = nil
|
||||
ac.state = Shutdown
|
||||
if ac.down != nil {
|
||||
ac.down(downErrorf(false, false, "%v", err))
|
||||
ac.down = nil
|
||||
}
|
||||
if cc.ready != nil {
|
||||
close(cc.ready)
|
||||
cc.ready = nil
|
||||
ac.stateCV.Broadcast()
|
||||
if ac.events != nil {
|
||||
ac.events.Finish()
|
||||
ac.events = nil
|
||||
}
|
||||
if cc.transport != nil {
|
||||
cc.transport.Close()
|
||||
if ac.ready != nil {
|
||||
close(ac.ready)
|
||||
ac.ready = nil
|
||||
}
|
||||
if cc.shutdownChan != nil {
|
||||
close(cc.shutdownChan)
|
||||
if ac.transport != nil {
|
||||
if err == errConnDrain {
|
||||
ac.transport.GracefulClose()
|
||||
} else {
|
||||
ac.transport.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
if ac.shutdownChan != nil {
|
||||
close(ac.shutdownChan)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,9 +54,9 @@ var (
|
|||
alpnProtoStr = []string{"h2"}
|
||||
)
|
||||
|
||||
// Credentials defines the common interface all supported credentials must
|
||||
// implement.
|
||||
type Credentials interface {
|
||||
// PerRPCCredentials defines the common interface for the credentials which need to
|
||||
// attach security information to every RPC (e.g., oauth2).
|
||||
type PerRPCCredentials interface {
|
||||
// GetRequestMetadata gets the current request metadata, refreshing
|
||||
// tokens if required. This should be called by the transport layer on
|
||||
// each request, and the data should be populated in headers or other
|
||||
|
|
@ -87,9 +87,9 @@ type AuthInfo interface {
|
|||
AuthType() string
|
||||
}
|
||||
|
||||
// TransportAuthenticator defines the common interface for all the live gRPC wire
|
||||
// TransportCredentials defines the common interface for all the live gRPC wire
|
||||
// protocols and supported transport security protocols (e.g., TLS, SSL).
|
||||
type TransportAuthenticator interface {
|
||||
type TransportCredentials interface {
|
||||
// ClientHandshake does the authentication handshake specified by the corresponding
|
||||
// authentication protocol on rawConn for clients. It returns the authenticated
|
||||
// connection and the corresponding auth information about the connection.
|
||||
|
|
@ -98,9 +98,8 @@ type TransportAuthenticator interface {
|
|||
// the authenticated connection and the corresponding auth information about
|
||||
// the connection.
|
||||
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
||||
// Info provides the ProtocolInfo of this TransportAuthenticator.
|
||||
// Info provides the ProtocolInfo of this TransportCredentials.
|
||||
Info() ProtocolInfo
|
||||
Credentials
|
||||
}
|
||||
|
||||
// TLSInfo contains the auth information for a TLS authenticated connection.
|
||||
|
|
@ -109,6 +108,7 @@ type TLSInfo struct {
|
|||
State tls.ConnectionState
|
||||
}
|
||||
|
||||
// AuthType returns the type of TLSInfo as a string.
|
||||
func (t TLSInfo) AuthType() string {
|
||||
return "tls"
|
||||
}
|
||||
|
|
@ -185,20 +185,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
|||
return conn, TLSInfo{conn.ConnectionState()}, nil
|
||||
}
|
||||
|
||||
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
|
||||
func NewTLS(c *tls.Config) TransportAuthenticator {
|
||||
// NewTLS uses c to construct a TransportCredentials based on TLS.
|
||||
func NewTLS(c *tls.Config) TransportCredentials {
|
||||
tc := &tlsCreds{*c}
|
||||
tc.config.NextProtos = alpnProtoStr
|
||||
return tc
|
||||
}
|
||||
|
||||
// NewClientTLSFromCert constructs a TLS from the input certificate for client.
|
||||
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator {
|
||||
func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials {
|
||||
return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp})
|
||||
}
|
||||
|
||||
// NewClientTLSFromFile constructs a TLS from the input certificate file for client.
|
||||
func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, error) {
|
||||
func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) {
|
||||
b, err := ioutil.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -211,13 +211,13 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator,
|
|||
}
|
||||
|
||||
// NewServerTLSFromCert constructs a TLS from the input certificate for server.
|
||||
func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator {
|
||||
func NewServerTLSFromCert(cert *tls.Certificate) TransportCredentials {
|
||||
return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}})
|
||||
}
|
||||
|
||||
// NewServerTLSFromFile constructs a TLS from the input certificate file and key
|
||||
// file for server.
|
||||
func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, error) {
|
||||
func NewServerTLSFromFile(certFile, keyFile string) (TransportCredentials, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ type Resolver interface {
|
|||
// Watcher watches for the updates on the specified target.
|
||||
type Watcher interface {
|
||||
// Next blocks until an update or error happens. It may return one or more
|
||||
// updates. The first call should get the full set of the results.
|
||||
// updates. The first call should get the full set of the results. It should
|
||||
// return an error if and only if Watcher cannot recover.
|
||||
Next() ([]*Update, error)
|
||||
// Close closes the Watcher.
|
||||
Close()
|
||||
|
|
|
|||
|
|
@ -1,243 +0,0 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2014, Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
* * Neither the name of Google Inc. nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/naming"
|
||||
"google.golang.org/grpc/transport"
|
||||
)
|
||||
|
||||
// Picker picks a Conn for RPC requests.
|
||||
// This is EXPERIMENTAL and please do not implement your own Picker for now.
|
||||
type Picker interface {
|
||||
// Init does initial processing for the Picker, e.g., initiate some connections.
|
||||
Init(cc *ClientConn) error
|
||||
// Pick blocks until either a transport.ClientTransport is ready for the upcoming RPC
|
||||
// or some error happens.
|
||||
Pick(ctx context.Context) (transport.ClientTransport, error)
|
||||
// PickAddr picks a peer address for connecting. This will be called repeated for
|
||||
// connecting/reconnecting.
|
||||
PickAddr() (string, error)
|
||||
// State returns the connectivity state of the underlying connections.
|
||||
State() (ConnectivityState, error)
|
||||
// WaitForStateChange blocks until the state changes to something other than
|
||||
// the sourceState. It returns the new state or error.
|
||||
WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error)
|
||||
// Close closes all the Conn's owned by this Picker.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// unicastPicker is the default Picker which is used when there is no custom Picker
|
||||
// specified by users. It always picks the same Conn.
|
||||
type unicastPicker struct {
|
||||
target string
|
||||
conn *Conn
|
||||
}
|
||||
|
||||
func (p *unicastPicker) Init(cc *ClientConn) error {
|
||||
c, err := NewConn(cc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.conn = c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *unicastPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
|
||||
return p.conn.Wait(ctx)
|
||||
}
|
||||
|
||||
func (p *unicastPicker) PickAddr() (string, error) {
|
||||
return p.target, nil
|
||||
}
|
||||
|
||||
func (p *unicastPicker) State() (ConnectivityState, error) {
|
||||
return p.conn.State(), nil
|
||||
}
|
||||
|
||||
func (p *unicastPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
|
||||
return p.conn.WaitForStateChange(ctx, sourceState)
|
||||
}
|
||||
|
||||
func (p *unicastPicker) Close() error {
|
||||
if p.conn != nil {
|
||||
return p.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// unicastNamingPicker picks an address from a name resolver to set up the connection.
|
||||
type unicastNamingPicker struct {
|
||||
cc *ClientConn
|
||||
resolver naming.Resolver
|
||||
watcher naming.Watcher
|
||||
mu sync.Mutex
|
||||
// The list of the addresses are obtained from watcher.
|
||||
addrs *list.List
|
||||
// It tracks the current picked addr by PickAddr(). The next PickAddr may
|
||||
// push it forward on addrs.
|
||||
pickedAddr *list.Element
|
||||
conn *Conn
|
||||
}
|
||||
|
||||
// NewUnicastNamingPicker creates a Picker to pick addresses from a name resolver
|
||||
// to connect.
|
||||
func NewUnicastNamingPicker(r naming.Resolver) Picker {
|
||||
return &unicastNamingPicker{
|
||||
resolver: r,
|
||||
addrs: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
type addrInfo struct {
|
||||
addr string
|
||||
// Set to true if this addrInfo needs to be deleted in the next PickAddrr() call.
|
||||
deleting bool
|
||||
}
|
||||
|
||||
// processUpdates calls Watcher.Next() once and processes the obtained updates.
|
||||
func (p *unicastNamingPicker) processUpdates() error {
|
||||
updates, err := p.watcher.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, update := range updates {
|
||||
switch update.Op {
|
||||
case naming.Add:
|
||||
p.mu.Lock()
|
||||
p.addrs.PushBack(&addrInfo{
|
||||
addr: update.Addr,
|
||||
})
|
||||
p.mu.Unlock()
|
||||
// Initial connection setup
|
||||
if p.conn == nil {
|
||||
conn, err := NewConn(p.cc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.conn = conn
|
||||
}
|
||||
case naming.Delete:
|
||||
p.mu.Lock()
|
||||
for e := p.addrs.Front(); e != nil; e = e.Next() {
|
||||
if update.Addr == e.Value.(*addrInfo).addr {
|
||||
if e == p.pickedAddr {
|
||||
// Do not remove the element now if it is the current picked
|
||||
// one. We leave the deletion to the next PickAddr() call.
|
||||
e.Value.(*addrInfo).deleting = true
|
||||
// Notify Conn to close it. All the live RPCs on this connection
|
||||
// will be aborted.
|
||||
p.conn.NotifyReset()
|
||||
} else {
|
||||
p.addrs.Remove(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
p.mu.Unlock()
|
||||
default:
|
||||
grpclog.Println("Unknown update.Op ", update.Op)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitor runs in a standalone goroutine to keep watching name resolution updates until the watcher
|
||||
// is closed.
|
||||
func (p *unicastNamingPicker) monitor() {
|
||||
for {
|
||||
if err := p.processUpdates(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *unicastNamingPicker) Init(cc *ClientConn) error {
|
||||
w, err := p.resolver.Resolve(cc.target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.watcher = w
|
||||
p.cc = cc
|
||||
// Get the initial name resolution.
|
||||
if err := p.processUpdates(); err != nil {
|
||||
return err
|
||||
}
|
||||
go p.monitor()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *unicastNamingPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
|
||||
return p.conn.Wait(ctx)
|
||||
}
|
||||
|
||||
func (p *unicastNamingPicker) PickAddr() (string, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.pickedAddr == nil {
|
||||
p.pickedAddr = p.addrs.Front()
|
||||
} else {
|
||||
pa := p.pickedAddr
|
||||
p.pickedAddr = pa.Next()
|
||||
if pa.Value.(*addrInfo).deleting {
|
||||
p.addrs.Remove(pa)
|
||||
}
|
||||
if p.pickedAddr == nil {
|
||||
p.pickedAddr = p.addrs.Front()
|
||||
}
|
||||
}
|
||||
if p.pickedAddr == nil {
|
||||
return "", fmt.Errorf("there is no address available to pick")
|
||||
}
|
||||
return p.pickedAddr.Value.(*addrInfo).addr, nil
|
||||
}
|
||||
|
||||
func (p *unicastNamingPicker) State() (ConnectivityState, error) {
|
||||
return 0, fmt.Errorf("State() is not supported for unicastNamingPicker")
|
||||
}
|
||||
|
||||
func (p *unicastNamingPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
|
||||
return 0, fmt.Errorf("WaitForStateChange is not supported for unicastNamingPciker")
|
||||
}
|
||||
|
||||
func (p *unicastNamingPicker) Close() error {
|
||||
p.watcher.Close()
|
||||
p.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
|
@ -61,7 +61,7 @@ type Codec interface {
|
|||
String() string
|
||||
}
|
||||
|
||||
// protoCodec is a Codec implemetation with protobuf. It is the default codec for gRPC.
|
||||
// protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC.
|
||||
type protoCodec struct{}
|
||||
|
||||
func (protoCodec) Marshal(v interface{}) ([]byte, error) {
|
||||
|
|
@ -187,7 +187,7 @@ const (
|
|||
compressionMade
|
||||
)
|
||||
|
||||
// parser reads complelete gRPC messages from the underlying reader.
|
||||
// parser reads complete gRPC messages from the underlying reader.
|
||||
type parser struct {
|
||||
// r is the underlying reader.
|
||||
// See the comment on recvMsg for the permissible
|
||||
|
|
@ -284,14 +284,11 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
|
|||
switch pf {
|
||||
case compressionNone:
|
||||
case compressionMade:
|
||||
if recvCompress == "" {
|
||||
return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress)
|
||||
}
|
||||
if dc == nil || recvCompress != dc.Type() {
|
||||
return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
|
||||
return transport.StreamErrorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
|
||||
}
|
||||
default:
|
||||
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
|
||||
return transport.StreamErrorf(codes.Internal, "grpc: received unexpected payload format %d", pf)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ type ServiceDesc struct {
|
|||
HandlerType interface{}
|
||||
Methods []MethodDesc
|
||||
Streams []StreamDesc
|
||||
Metadata interface{}
|
||||
}
|
||||
|
||||
// service consists of the information of the server serving this service and
|
||||
|
|
@ -95,10 +96,12 @@ type Server struct {
|
|||
}
|
||||
|
||||
type options struct {
|
||||
creds credentials.Credentials
|
||||
creds credentials.TransportCredentials
|
||||
codec Codec
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
unaryInt UnaryServerInterceptor
|
||||
streamInt StreamServerInterceptor
|
||||
maxConcurrentStreams uint32
|
||||
useHandlerImpl bool // use http.Handler-based server
|
||||
}
|
||||
|
|
@ -113,12 +116,14 @@ func CustomCodec(codec Codec) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
|
||||
func RPCCompressor(cp Compressor) ServerOption {
|
||||
return func(o *options) {
|
||||
o.cp = cp
|
||||
}
|
||||
}
|
||||
|
||||
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
|
||||
func RPCDecompressor(dc Decompressor) ServerOption {
|
||||
return func(o *options) {
|
||||
o.dc = dc
|
||||
|
|
@ -134,12 +139,35 @@ func MaxConcurrentStreams(n uint32) ServerOption {
|
|||
}
|
||||
|
||||
// Creds returns a ServerOption that sets credentials for server connections.
|
||||
func Creds(c credentials.Credentials) ServerOption {
|
||||
func Creds(c credentials.TransportCredentials) ServerOption {
|
||||
return func(o *options) {
|
||||
o.creds = c
|
||||
}
|
||||
}
|
||||
|
||||
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
|
||||
// server. Only one unary interceptor can be installed. The construction of multiple
|
||||
// interceptors (e.g., chaining) can be implemented at the caller.
|
||||
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
|
||||
return func(o *options) {
|
||||
if o.unaryInt != nil {
|
||||
panic("The unary server interceptor has been set.")
|
||||
}
|
||||
o.unaryInt = i
|
||||
}
|
||||
}
|
||||
|
||||
// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
|
||||
// server. Only one stream interceptor can be installed.
|
||||
func StreamInterceptor(i StreamServerInterceptor) ServerOption {
|
||||
return func(o *options) {
|
||||
if o.streamInt != nil {
|
||||
panic("The stream server interceptor has been set.")
|
||||
}
|
||||
o.streamInt = i
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer creates a gRPC server which has no service registered and has not
|
||||
// started to accept requests yet.
|
||||
func NewServer(opt ...ServerOption) *Server {
|
||||
|
|
@ -222,22 +250,23 @@ var (
|
|||
)
|
||||
|
||||
func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
|
||||
if !ok {
|
||||
if s.opts.creds == nil {
|
||||
return rawConn, nil, nil
|
||||
}
|
||||
return creds.ServerHandshake(rawConn)
|
||||
return s.opts.creds.ServerHandshake(rawConn)
|
||||
}
|
||||
|
||||
// Serve accepts incoming connections on the listener lis, creating a new
|
||||
// ServerTransport and service goroutine for each. The service goroutines
|
||||
// read gRPC requests and then call the registered handlers to reply to them.
|
||||
// Service returns when lis.Accept fails.
|
||||
// Service returns when lis.Accept fails. lis will be closed when
|
||||
// this method returns.
|
||||
func (s *Server) Serve(lis net.Listener) error {
|
||||
s.mu.Lock()
|
||||
s.printf("serving")
|
||||
if s.lis == nil {
|
||||
s.mu.Unlock()
|
||||
lis.Close()
|
||||
return ErrServerStopped
|
||||
}
|
||||
s.lis[lis] = true
|
||||
|
|
@ -435,6 +464,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
}
|
||||
}()
|
||||
}
|
||||
if s.opts.cp != nil {
|
||||
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
|
||||
stream.SetSendCompress(s.opts.cp.Type())
|
||||
}
|
||||
p := &parser{r: stream}
|
||||
for {
|
||||
pf, req, err := p.recvMsg()
|
||||
|
|
@ -494,7 +527,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
}
|
||||
return nil
|
||||
}
|
||||
reply, appErr := md.Handler(srv.server, stream.Context(), df, nil)
|
||||
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
|
||||
if appErr != nil {
|
||||
if err, ok := appErr.(rpcError); ok {
|
||||
statusCode = err.code
|
||||
|
|
@ -520,9 +553,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
Last: true,
|
||||
Delay: false,
|
||||
}
|
||||
if s.opts.cp != nil {
|
||||
stream.SetSendCompress(s.opts.cp.Type())
|
||||
}
|
||||
if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
|
||||
switch err := err.(type) {
|
||||
case transport.ConnectionError:
|
||||
|
|
@ -572,7 +602,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||
ss.mu.Unlock()
|
||||
}()
|
||||
}
|
||||
if appErr := sd.Handler(srv.server, ss); appErr != nil {
|
||||
var appErr error
|
||||
if s.opts.streamInt == nil {
|
||||
appErr = sd.Handler(srv.server, ss)
|
||||
} else {
|
||||
info := &StreamServerInfo{
|
||||
FullMethod: stream.Method(),
|
||||
IsClientStream: sd.ClientStreams,
|
||||
IsServerStream: sd.ServerStreams,
|
||||
}
|
||||
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
|
||||
}
|
||||
if appErr != nil {
|
||||
if err, ok := appErr.(rpcError); ok {
|
||||
ss.statusCode = err.code
|
||||
ss.statusDesc = err.desc
|
||||
|
|
|
|||
|
|
@ -79,9 +79,9 @@ type Stream interface {
|
|||
RecvMsg(m interface{}) error
|
||||
}
|
||||
|
||||
// ClientStream defines the interface a client stream has to satify.
|
||||
// ClientStream defines the interface a client stream has to satisfy.
|
||||
type ClientStream interface {
|
||||
// Header returns the header metedata received from the server if there
|
||||
// Header returns the header metadata received from the server if there
|
||||
// is any. It blocks if the metadata is not ready to read.
|
||||
Header() (metadata.MD, error)
|
||||
// Trailer returns the trailer metadata from the server. It must be called
|
||||
|
|
@ -103,12 +103,16 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||
var (
|
||||
t transport.ClientTransport
|
||||
err error
|
||||
put func()
|
||||
)
|
||||
t, err = cc.dopts.picker.Pick(ctx)
|
||||
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
|
||||
gopts := BalancerGetOptions{
|
||||
BlockingWait: false,
|
||||
}
|
||||
t, put, err = cc.getTransport(ctx, gopts)
|
||||
if err != nil {
|
||||
return nil, toRPCErr(err)
|
||||
}
|
||||
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
|
||||
callHdr := &transport.CallHdr{
|
||||
Host: cc.authority,
|
||||
Method: method,
|
||||
|
|
@ -119,6 +123,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||
}
|
||||
cs := &clientStream{
|
||||
desc: desc,
|
||||
put: put,
|
||||
codec: cc.dopts.codec,
|
||||
cp: cc.dopts.cp,
|
||||
dc: cc.dopts.dc,
|
||||
|
|
@ -174,6 +179,7 @@ type clientStream struct {
|
|||
tracing bool // set to EnableTracing when the clientStream is created.
|
||||
|
||||
mu sync.Mutex
|
||||
put func()
|
||||
closed bool
|
||||
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
|
||||
// and is set to nil when the clientStream's finish method is called.
|
||||
|
|
@ -311,6 +317,10 @@ func (cs *clientStream) finish(err error) {
|
|||
}
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
if cs.put != nil {
|
||||
cs.put()
|
||||
cs.put = nil
|
||||
}
|
||||
if cs.trInfo.tr != nil {
|
||||
if err == nil || err == io.EOF {
|
||||
cs.trInfo.tr.LazyPrintf("RPC: [OK]")
|
||||
|
|
|
|||
|
|
@ -101,9 +101,8 @@ type payload struct {
|
|||
func (p payload) String() string {
|
||||
if p.sent {
|
||||
return fmt.Sprintf("sent: %v", p.msg)
|
||||
} else {
|
||||
return fmt.Sprintf("recv: %v", p.msg)
|
||||
}
|
||||
return fmt.Sprintf("recv: %v", p.msg)
|
||||
}
|
||||
|
||||
type fmtStringer struct {
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
|
|||
if r.Method != "POST" {
|
||||
return nil, errors.New("invalid gRPC request method")
|
||||
}
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
|
||||
if !validContentType(r.Header.Get("Content-Type")) {
|
||||
return nil, errors.New("invalid gRPC request content-type")
|
||||
}
|
||||
if _, ok := w.(http.Flusher); !ok {
|
||||
|
|
@ -92,9 +92,12 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
|
|||
}
|
||||
|
||||
var metakv []string
|
||||
if r.Host != "" {
|
||||
metakv = append(metakv, ":authority", r.Host)
|
||||
}
|
||||
for k, vv := range r.Header {
|
||||
k = strings.ToLower(k)
|
||||
if isReservedHeader(k) {
|
||||
if isReservedHeader(k) && !isWhitelistedPseudoHeader(k) {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
|
|
@ -108,7 +111,6 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
|
|||
}
|
||||
}
|
||||
metakv = append(metakv, k, v)
|
||||
|
||||
}
|
||||
}
|
||||
st.headerMD = metadata.Pairs(metakv...)
|
||||
|
|
@ -196,6 +198,10 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
|
|||
}
|
||||
if md := s.Trailer(); len(md) > 0 {
|
||||
for k, vv := range md {
|
||||
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
|
||||
if isReservedHeader(k) {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
// http2 ResponseWriter mechanism to
|
||||
// send undeclared Trailers after the
|
||||
|
|
@ -249,6 +255,10 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
|||
ht.writeCommonHeaders(s)
|
||||
h := ht.rw.Header()
|
||||
for k, vv := range md {
|
||||
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
|
||||
if isReservedHeader(k) {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
h.Add(k, v)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ package transport
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
|
|
@ -89,7 +88,7 @@ type http2Client struct {
|
|||
// The scheme used: https if TLS is on, http otherwise.
|
||||
scheme string
|
||||
|
||||
authCreds []credentials.Credentials
|
||||
creds []credentials.PerRPCCredentials
|
||||
|
||||
mu sync.Mutex // guard the following variables
|
||||
state transportState // the state of underlying connection
|
||||
|
|
@ -118,19 +117,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
}
|
||||
var authInfo credentials.AuthInfo
|
||||
for _, c := range opts.AuthOptions {
|
||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
||||
scheme = "https"
|
||||
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
|
||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
||||
// place the ClientTransport construction into a separate function to make
|
||||
// things clear.
|
||||
if timeout > 0 {
|
||||
timeout -= time.Since(startT)
|
||||
}
|
||||
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
||||
break
|
||||
if opts.TransportCredentials != nil {
|
||||
scheme = "https"
|
||||
if timeout > 0 {
|
||||
timeout -= time.Since(startT)
|
||||
}
|
||||
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
|
||||
}
|
||||
if connErr != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
|
|
@ -140,29 +132,6 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||
conn.Close()
|
||||
}
|
||||
}()
|
||||
// Send connection preface to server.
|
||||
n, err := conn.Write(clientPreface)
|
||||
if err != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", err)
|
||||
}
|
||||
if n != len(clientPreface) {
|
||||
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
|
||||
}
|
||||
framer := newFramer(conn)
|
||||
if initialWindowSize != defaultWindowSize {
|
||||
err = framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
|
||||
} else {
|
||||
err = framer.writeSettings(true)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", err)
|
||||
}
|
||||
// Adjust the connection flow control window if needed.
|
||||
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
||||
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", err)
|
||||
}
|
||||
}
|
||||
ua := primaryUA
|
||||
if opts.UserAgent != "" {
|
||||
ua = opts.UserAgent + " " + ua
|
||||
|
|
@ -178,7 +147,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||
writableChan: make(chan int, 1),
|
||||
shutdownChan: make(chan struct{}),
|
||||
errorChan: make(chan struct{}),
|
||||
framer: framer,
|
||||
framer: newFramer(conn),
|
||||
hBuf: &buf,
|
||||
hEnc: hpack.NewEncoder(&buf),
|
||||
controlBuf: newRecvBuffer(),
|
||||
|
|
@ -187,17 +156,42 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||
scheme: scheme,
|
||||
state: reachable,
|
||||
activeStreams: make(map[uint32]*Stream),
|
||||
authCreds: opts.AuthOptions,
|
||||
creds: opts.PerRPCCredentials,
|
||||
maxStreams: math.MaxInt32,
|
||||
streamSendQuota: defaultWindowSize,
|
||||
}
|
||||
// Start the reader goroutine for incoming message. Each transport has
|
||||
// a dedicated goroutine which reads HTTP2 frame from network. Then it
|
||||
// dispatches the frame to the corresponding stream entity.
|
||||
go t.reader()
|
||||
// Send connection preface to server.
|
||||
n, err := t.conn.Write(clientPreface)
|
||||
if err != nil {
|
||||
t.Close()
|
||||
return nil, ConnectionErrorf("transport: %v", err)
|
||||
}
|
||||
if n != len(clientPreface) {
|
||||
t.Close()
|
||||
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
|
||||
}
|
||||
if initialWindowSize != defaultWindowSize {
|
||||
err = t.framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
|
||||
} else {
|
||||
err = t.framer.writeSettings(true)
|
||||
}
|
||||
if err != nil {
|
||||
t.Close()
|
||||
return nil, ConnectionErrorf("transport: %v", err)
|
||||
}
|
||||
// Adjust the connection flow control window if needed.
|
||||
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
||||
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
|
||||
t.Close()
|
||||
return nil, ConnectionErrorf("transport: %v", err)
|
||||
}
|
||||
}
|
||||
go t.controller()
|
||||
t.writableChan <- 0
|
||||
// Start the reader goroutine for incoming message. The threading model
|
||||
// on receiving is that each transport has a dedicated goroutine which
|
||||
// reads HTTP2 frame from network. Then it dispatches the frame to the
|
||||
// corresponding stream entity.
|
||||
go t.reader()
|
||||
return t, nil
|
||||
}
|
||||
|
||||
|
|
@ -247,7 +241,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
}
|
||||
ctx = peer.NewContext(ctx, pr)
|
||||
authData := make(map[string]string)
|
||||
for _, c := range t.authCreds {
|
||||
for _, c := range t.creds {
|
||||
// Construct URI required to get auth request metadata.
|
||||
var port string
|
||||
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
||||
|
|
@ -270,6 +264,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
}
|
||||
}
|
||||
t.mu.Lock()
|
||||
if t.activeStreams == nil {
|
||||
t.mu.Unlock()
|
||||
return nil, ErrConnClosing
|
||||
}
|
||||
if t.state != reachable {
|
||||
t.mu.Unlock()
|
||||
return nil, ErrConnClosing
|
||||
|
|
@ -287,7 +285,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
}
|
||||
}
|
||||
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
|
||||
// t.streamsQuota will be updated when t.CloseStream is invoked.
|
||||
// Return the quota back now because there is no stream returned to the caller.
|
||||
if _, ok := err.(StreamError); ok && checkStreamsQuota {
|
||||
t.streamsQuota.add(1)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
t.mu.Lock()
|
||||
|
|
@ -339,6 +340,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
hasMD = true
|
||||
for k, v := range md {
|
||||
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
|
||||
if isReservedHeader(k) {
|
||||
continue
|
||||
}
|
||||
for _, entry := range v {
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
|
||||
}
|
||||
|
|
@ -388,9 +393,19 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
func (t *http2Client) CloseStream(s *Stream, err error) {
|
||||
var updateStreams bool
|
||||
t.mu.Lock()
|
||||
if t.activeStreams == nil {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if t.streamsQuota != nil {
|
||||
updateStreams = true
|
||||
}
|
||||
if t.state == draining && len(t.activeStreams) == 1 {
|
||||
// The transport is draining and s is the last live stream on t.
|
||||
t.mu.Unlock()
|
||||
t.Close()
|
||||
return
|
||||
}
|
||||
delete(t.activeStreams, s.id)
|
||||
t.mu.Unlock()
|
||||
if updateStreams {
|
||||
|
|
@ -427,9 +442,12 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
|
|||
// accessed any more.
|
||||
func (t *http2Client) Close() (err error) {
|
||||
t.mu.Lock()
|
||||
if t.state == reachable {
|
||||
close(t.errorChan)
|
||||
}
|
||||
if t.state == closing {
|
||||
t.mu.Unlock()
|
||||
return errors.New("transport: Close() was already called")
|
||||
return
|
||||
}
|
||||
t.state = closing
|
||||
t.mu.Unlock()
|
||||
|
|
@ -452,6 +470,25 @@ func (t *http2Client) Close() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (t *http2Client) GracefulClose() error {
|
||||
t.mu.Lock()
|
||||
if t.state == closing {
|
||||
t.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
if t.state == draining {
|
||||
t.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
t.state = draining
|
||||
active := len(t.activeStreams)
|
||||
t.mu.Unlock()
|
||||
if active == 0 {
|
||||
return t.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
|
||||
// should proceed only if Write returns nil.
|
||||
// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later
|
||||
|
|
@ -574,6 +611,11 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
|
|||
// Window updates will deliver to the controller for sending when
|
||||
// the cumulative quota exceeds the corresponding threshold.
|
||||
func (t *http2Client) updateWindow(s *Stream, n uint32) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.state == streamDone {
|
||||
return
|
||||
}
|
||||
if w := t.fc.onRead(n); w > 0 {
|
||||
t.controlBuf.put(&windowUpdate{0, w})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -303,6 +303,11 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
|
|||
// Window updates will deliver to the controller for sending when
|
||||
// the cumulative quota exceeds the corresponding threshold.
|
||||
func (t *http2Server) updateWindow(s *Stream, n uint32) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.state == streamDone {
|
||||
return
|
||||
}
|
||||
if w := t.fc.onRead(n); w > 0 {
|
||||
t.controlBuf.put(&windowUpdate{0, w})
|
||||
}
|
||||
|
|
@ -455,6 +460,10 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
|
|||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
|
||||
}
|
||||
for k, v := range md {
|
||||
if isReservedHeader(k) {
|
||||
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
|
||||
continue
|
||||
}
|
||||
for _, entry := range v {
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
|
||||
}
|
||||
|
|
@ -497,6 +506,10 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
|
|||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc})
|
||||
// Attach the trailer metadata.
|
||||
for k, v := range s.trailer {
|
||||
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
|
||||
if isReservedHeader(k) {
|
||||
continue
|
||||
}
|
||||
for _, entry := range v {
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -127,16 +127,40 @@ func isReservedHeader(hdr string) bool {
|
|||
}
|
||||
}
|
||||
|
||||
// isWhitelistedPseudoHeader checks whether hdr belongs to HTTP2 pseudoheaders
|
||||
// that should be propagated into metadata visible to users.
|
||||
func isWhitelistedPseudoHeader(hdr string) bool {
|
||||
switch hdr {
|
||||
case ":authority":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (d *decodeState) setErr(err error) {
|
||||
if d.err == nil {
|
||||
d.err = err
|
||||
}
|
||||
}
|
||||
|
||||
func validContentType(t string) bool {
|
||||
e := "application/grpc"
|
||||
if !strings.HasPrefix(t, e) {
|
||||
return false
|
||||
}
|
||||
// Support variations on the content-type
|
||||
// (e.g. "application/grpc+blah", "application/grpc;blah").
|
||||
if len(t) > len(e) && t[len(e)] != '+' && t[len(e)] != ';' {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *decodeState) processHeaderField(f hpack.HeaderField) {
|
||||
switch f.Name {
|
||||
case "content-type":
|
||||
if !strings.Contains(f.Value, "application/grpc") {
|
||||
if !validContentType(f.Value) {
|
||||
d.setErr(StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value))
|
||||
return
|
||||
}
|
||||
|
|
@ -162,7 +186,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
|
|||
case ":path":
|
||||
d.method = f.Value
|
||||
default:
|
||||
if !isReservedHeader(f.Name) {
|
||||
if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) {
|
||||
if f.Name == "user-agent" {
|
||||
i := strings.LastIndex(f.Value, " ")
|
||||
if i == -1 {
|
||||
|
|
|
|||
|
|
@ -321,6 +321,7 @@ const (
|
|||
reachable transportState = iota
|
||||
unreachable
|
||||
closing
|
||||
draining
|
||||
)
|
||||
|
||||
// NewServerTransport creates a ServerTransport with conn or non-nil error
|
||||
|
|
@ -335,9 +336,11 @@ type ConnectOptions struct {
|
|||
UserAgent string
|
||||
// Dialer specifies how to dial a network address.
|
||||
Dialer func(string, time.Duration) (net.Conn, error)
|
||||
// AuthOptions stores the credentials required to setup a client connection and/or issue RPCs.
|
||||
AuthOptions []credentials.Credentials
|
||||
// Timeout specifies the timeout for dialing a client connection.
|
||||
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
|
||||
PerRPCCredentials []credentials.PerRPCCredentials
|
||||
// TransportCredentials stores the Authenticator required to setup a client connection.
|
||||
TransportCredentials credentials.TransportCredentials
|
||||
// Timeout specifies the timeout for dialing a ClientTransport.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
|
|
@ -391,6 +394,10 @@ type ClientTransport interface {
|
|||
// is called only once.
|
||||
Close() error
|
||||
|
||||
// GracefulClose starts to tear down the transport. It stops accepting
|
||||
// new RPCs and wait the completion of the pending RPCs.
|
||||
GracefulClose() error
|
||||
|
||||
// Write sends the data for the given stream. A nil stream indicates
|
||||
// the write is to be performed on the transport as a whole.
|
||||
Write(s *Stream, data []byte, opts *Options) error
|
||||
|
|
@ -468,7 +475,7 @@ func (e ConnectionError) Error() string {
|
|||
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
||||
}
|
||||
|
||||
// Define some common ConnectionErrors.
|
||||
// ErrConnClosing indicates that the transport is closing.
|
||||
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
||||
|
||||
// StreamError is an error that only affects one stream within a connection.
|
||||
|
|
|
|||
Loading…
Reference in New Issue