feat: add grpc mux transport (#1602)

Signed-off-by: Jim Ma <majinjing3@gmail.com>
This commit is contained in:
Jim Ma 2022-08-30 14:20:43 +08:00 committed by Gaius
parent 9ab33635c5
commit 98fb1fc427
No known key found for this signature in database
GPG Key ID: 8B4E5D1290FA2FFB
11 changed files with 315 additions and 13 deletions

View File

@ -20,12 +20,14 @@ import (
"crypto/tls" "crypto/tls"
"github.com/johanbrandhorst/certify" "github.com/johanbrandhorst/certify"
"d7y.io/dragonfly/v2/pkg/net/ip"
) )
func GetCertificate(certifyClient *certify.Certify) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { func GetCertificate(certifyClient *certify.Certify) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
// FIXME peers need pure ip cert, certify checks the ServerName, so workaround here // FIXME peers need pure ip cert, certify checks the ServerName, so workaround here
hello.ServerName = "peer" hello.ServerName = ip.IPv4
return certifyClient.GetCertificate(hello) return certifyClient.GetCertificate(hello)
} }
} }

View File

@ -174,11 +174,26 @@ func ConvertPattern(p string, defaultPattern commonv1.Pattern) commonv1.Pattern
} }
type GlobalSecurityOption struct { type GlobalSecurityOption struct {
// AutoIssueCert indicates to issue client certificates for all grpc call
// if AutoIssueCert is false, any other option in Security will be ignored
AutoIssueCert bool `mapstructure:"autoIssueCert" yaml:"autoIssueCert"` AutoIssueCert bool `mapstructure:"autoIssueCert" yaml:"autoIssueCert"`
// CACert is the root CA certificate for all grpc tls handshake, it can be path or PEM format string
CACert serialize.PEMContent `mapstructure:"caCert" yaml:"caCert"` CACert serialize.PEMContent `mapstructure:"caCert" yaml:"caCert"`
// TLSPrefer indicates to verify client certificates for grpc ServerHandshake
TLSVerify bool `mapstructure:"tlsVerify" yaml:"tlsVerify"` TLSVerify bool `mapstructure:"tlsVerify" yaml:"tlsVerify"`
// TLSPolicy controls the grpc shandshake behaviors:
// force: both ClientHandshake and ServerHandshake are only support tls
// prefer: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support tls
// default or empty: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support insecure (non-tls)
TLSPolicy string `mapstructure:"tlsPolicy" yaml:"tlsPolicy"`
} }
const (
TLSPolicyForce = "force"
TLSPolicyPrefer = "prefer"
TLSPolicyDefault = "default"
)
type SchedulerOption struct { type SchedulerOption struct {
// Manager is to get the scheduler configuration remotely. // Manager is to get the scheduler configuration remotely.
Manager ManagerOption `mapstructure:"manager" yaml:"manager"` Manager ManagerOption `mapstructure:"manager" yaml:"manager"`

View File

@ -81,7 +81,7 @@ var peerHostConfig = func() *DaemonOption {
DownloadGRPC: ListenOption{ DownloadGRPC: ListenOption{
Security: SecurityOption{ Security: SecurityOption{
Insecure: true, Insecure: true,
TLSVerify: true, TLSVerify: false,
}, },
UnixListen: &UnixListenOption{}, UnixListen: &UnixListenOption{},
}, },
@ -176,5 +176,11 @@ var peerHostConfig = func() *DaemonOption {
Duration: time.Minute, Duration: time.Minute,
}, },
}, },
Security: GlobalSecurityOption{
AutoIssueCert: false,
CACert: serialize.PEMContent(""),
TLSVerify: false,
TLSPolicy: TLSPolicyDefault,
},
} }
} }

View File

@ -29,6 +29,7 @@ import (
"d7y.io/dragonfly/v2/pkg/dfnet" "d7y.io/dragonfly/v2/pkg/dfnet"
"d7y.io/dragonfly/v2/pkg/net/fqdn" "d7y.io/dragonfly/v2/pkg/net/fqdn"
"d7y.io/dragonfly/v2/pkg/net/ip" "d7y.io/dragonfly/v2/pkg/net/ip"
"d7y.io/dragonfly/v2/pkg/serialize"
) )
var peerHostConfig = func() *DaemonOption { var peerHostConfig = func() *DaemonOption {
@ -81,7 +82,7 @@ var peerHostConfig = func() *DaemonOption {
DownloadGRPC: ListenOption{ DownloadGRPC: ListenOption{
Security: SecurityOption{ Security: SecurityOption{
Insecure: true, Insecure: true,
TLSVerify: true, TLSVerify: false,
}, },
UnixListen: &UnixListenOption{}, UnixListen: &UnixListenOption{},
}, },
@ -175,5 +176,11 @@ var peerHostConfig = func() *DaemonOption {
Duration: time.Minute, Duration: time.Minute,
}, },
}, },
Security: GlobalSecurityOption{
AutoIssueCert: false,
CACert: serialize.PEMContent(""),
TLSVerify: false,
TLSPolicy: TLSPolicyDefault,
},
} }
} }

View File

@ -145,13 +145,15 @@ func New(opt *config.DaemonOption, d dfpath.Dfpath) (Daemon, error) {
} }
// issue a certificate to reduce first time delay // issue a certificate to reduce first time delay
_, err := certifyClient.GetCertificate(&tls.ClientHelloInfo{ cert, err := certifyClient.GetCertificate(&tls.ClientHelloInfo{
ServerName: ip.IPv4, ServerName: ip.IPv4,
}) })
if err != nil { if err != nil {
logger.Errorf("issue certificate error: %s", err.Error()) logger.Errorf("issue certificate error: %s", err.Error())
return nil, err return nil, err
} }
logger.Debugf("request cert from manager, common name: %s, issuer: %s",
cert.Leaf.Subject.CommonName, cert.Leaf.Issuer.CommonName)
} }
// New dynconfig manager client. // New dynconfig manager client.
@ -337,6 +339,7 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
} }
opt.TLSConfig.ClientCAs = certPool opt.TLSConfig.ClientCAs = certPool
opt.TLSConfig.RootCAs = certPool
// Load server's certificate and private key // Load server's certificate and private key
if certifyClient == nil { if certifyClient == nil {
@ -347,7 +350,6 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
opt.TLSConfig.Certificates = []tls.Certificate{serverCert} opt.TLSConfig.Certificates = []tls.Certificate{serverCert}
} else { } else {
// enable auto issue certificate // enable auto issue certificate
opt.TLSConfig.Certificates = nil
opt.TLSConfig.GetCertificate = config.GetCertificate(certifyClient) opt.TLSConfig.GetCertificate = config.GetCertificate(certifyClient)
opt.TLSConfig.GetClientCertificate = certifyClient.GetClientCertificate opt.TLSConfig.GetClientCertificate = certifyClient.GetClientCertificate
} }
@ -356,7 +358,15 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
opt.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert opt.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
} }
switch security.TLSPolicy {
case config.TLSPolicyDefault, config.TLSPolicyPrefer:
return rpc.NewMuxTransportCredentials(opt.TLSConfig,
rpc.WithTLSPreferClientHandshake(security.TLSPolicy == config.TLSPolicyPrefer)), nil
case config.TLSPolicyForce:
return credentials.NewTLS(opt.TLSConfig), nil return credentials.NewTLS(opt.TLSConfig), nil
default:
return nil, fmt.Errorf("invalid tlsPolicy: %s", security.TLSPolicy)
}
} }
func loadGlobalGPRCTLSCredentials(certifyClient *certify.Certify, security config.GlobalSecurityOption) (credentials.TransportCredentials, error) { func loadGlobalGPRCTLSCredentials(certifyClient *certify.Certify, security config.GlobalSecurityOption) (credentials.TransportCredentials, error) {
@ -370,17 +380,26 @@ func loadGlobalGPRCTLSCredentials(certifyClient *certify.Certify, security confi
return nil, fmt.Errorf("failed to add global CA's certificate") return nil, fmt.Errorf("failed to add global CA's certificate")
} }
config := &tls.Config{ tlsConfig := &tls.Config{
ClientCAs: certPool, ClientCAs: certPool,
RootCAs: certPool,
GetCertificate: config.GetCertificate(certifyClient), GetCertificate: config.GetCertificate(certifyClient),
GetClientCertificate: certifyClient.GetClientCertificate, GetClientCertificate: certifyClient.GetClientCertificate,
} }
if security.TLSVerify { if security.TLSVerify {
config.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
} }
return credentials.NewTLS(config), nil switch security.TLSPolicy {
case config.TLSPolicyDefault, config.TLSPolicyPrefer:
return rpc.NewMuxTransportCredentials(tlsConfig,
rpc.WithTLSPreferClientHandshake(security.TLSPolicy == config.TLSPolicyPrefer)), nil
case config.TLSPolicyForce:
return credentials.NewTLS(tlsConfig), nil
default:
return nil, fmt.Errorf("invalid tlsPolicy: %s", security.TLSPolicy)
}
} }
func (*clientDaemon) prepareTCPListener(opt config.ListenOption, withTLS bool) (net.Listener, int, error) { func (*clientDaemon) prepareTCPListener(opt config.ListenOption, withTLS bool) (net.Listener, int, error) {

View File

@ -485,7 +485,7 @@ func (pm *pieceManager) downloadKnownLengthSource(ctx context.Context, pt Task,
pt.ReportPieceResult(request, result, nil) pt.ReportPieceResult(request, result, nil)
pt.PublishPieceInfo(pieceNum, uint32(result.Size)) pt.PublishPieceInfo(pieceNum, uint32(result.Size))
if supportConcurrent { if supportConcurrent && pieceNum+2 < maxPieceNum {
// the time unit of FinishTime and BeginTime is ns // the time unit of FinishTime and BeginTime is ns
speed := float64(pieceSize) / float64((result.FinishTime-result.BeginTime)/1000000) speed := float64(pieceSize) / float64((result.FinishTime-result.BeginTime)/1000000)
if speed < float64(pm.concurrentOption.ThresholdSpeed) { if speed < float64(pm.concurrentOption.ThresholdSpeed) {

View File

@ -38,6 +38,8 @@ import (
"go.uber.org/atomic" "go.uber.org/atomic"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
grpcpeer "google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
@ -211,7 +213,28 @@ func (s *server) sendFirstPieceTasks(
return sendExistPieces(sync.Context(), log, get, request, sync, sentMap, false) return sendExistPieces(sync.Context(), log, get, request, sync, sentMap, false)
} }
func printAuthInfo(ctx context.Context) {
if peerInfo, ok := grpcpeer.FromContext(ctx); ok {
if tlsInfo, ok := peerInfo.AuthInfo.(credentials.TLSInfo); ok {
for i, pc := range tlsInfo.State.PeerCertificates {
logger.Debugf("peer cert depth %d, issuer: %#v", i, pc.Issuer.CommonName)
logger.Debugf("peer cert depth %d, common name: %#v", i, pc.Subject.CommonName)
if len(pc.IPAddresses) > 0 {
logger.Debugf("peer cert depth %d, ip: %#v", i, pc.IPAddresses)
}
if len(pc.DNSNames) > 0 {
logger.Debugf("peer cert depth %d, dns: %#v", i, pc.DNSNames)
}
}
}
}
}
func (s *server) SyncPieceTasks(sync dfdaemonv1.Daemon_SyncPieceTasksServer) error { func (s *server) SyncPieceTasks(sync dfdaemonv1.Daemon_SyncPieceTasksServer) error {
if logger.IsDebug() {
printAuthInfo(sync.Context())
}
request, err := sync.Recv() request, err := sync.Recv()
if err != nil { if err != nil {
logger.Errorf("receive first sync piece tasks request error: %s", err.Error()) logger.Errorf("receive first sync piece tasks request error: %s", err.Error())

View File

@ -52,6 +52,10 @@ func (s *seeder) SyncPieceTasks(tasksServer cdnsystemv1.Seeder_SyncPieceTasksSer
} }
func (s *seeder) ObtainSeeds(seedRequest *cdnsystemv1.SeedRequest, seedsServer cdnsystemv1.Seeder_ObtainSeedsServer) error { func (s *seeder) ObtainSeeds(seedRequest *cdnsystemv1.SeedRequest, seedsServer cdnsystemv1.Seeder_ObtainSeedsServer) error {
if logger.IsDebug() {
printAuthInfo(seedsServer.Context())
}
metrics.SeedPeerConcurrentDownloadGauge.Inc() metrics.SeedPeerConcurrentDownloadGauge.Inc()
defer metrics.SeedPeerConcurrentDownloadGauge.Dec() defer metrics.SeedPeerConcurrentDownloadGauge.Dec()
metrics.SeedPeerDownloadCount.Add(1) metrics.SeedPeerDownloadCount.Add(1)

View File

@ -215,6 +215,10 @@ func (log *SugaredLoggerOnWith) Debug(args ...any) {
CoreLogger.Debugw(fmt.Sprint(args...), log.withArgs...) CoreLogger.Debugw(fmt.Sprint(args...), log.withArgs...)
} }
func (log *SugaredLoggerOnWith) IsDebug() bool {
return coreLogLevelEnabler.Enabled(zap.DebugLevel)
}
func Infof(template string, args ...any) { func Infof(template string, args ...any) {
CoreLogger.Infof(template, args...) CoreLogger.Infof(template, args...)
} }
@ -243,6 +247,14 @@ func Debugf(template string, args ...any) {
CoreLogger.Debugf(template, args...) CoreLogger.Debugf(template, args...)
} }
func Debug(args ...any) {
CoreLogger.Debug(args...)
}
func IsDebug() bool {
return coreLogLevelEnabler.Enabled(zap.DebugLevel)
}
func Fatalf(template string, args ...any) { func Fatalf(template string, args ...any) {
CoreLogger.Fatalf(template, args...) CoreLogger.Fatalf(template, args...)
} }

145
pkg/rpc/mux.go Normal file
View File

@ -0,0 +1,145 @@
/*
* Copyright 2022 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rpc
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"net"
"github.com/soheilhy/cmux"
"google.golang.org/grpc/credentials"
)
const (
// cmux's TLS matcher read at least 3 + 1 bytes
tlsRecordPrefix = 4
)
type muxTransportCredentials struct {
credentials credentials.TransportCredentials
tlsMatcher cmux.Matcher
tlsPrefer bool
}
func WithTLSPreferClientHandshake(prefer bool) func(m *muxTransportCredentials) {
return func(m *muxTransportCredentials) {
m.tlsPrefer = prefer
}
}
func NewMuxTransportCredentials(c *tls.Config, opts ...func(m *muxTransportCredentials)) credentials.TransportCredentials {
m := &muxTransportCredentials{
tlsMatcher: cmux.TLS(),
credentials: credentials.NewTLS(c),
}
for _, opt := range opts {
opt(m)
}
return m
}
func (m *muxTransportCredentials) ClientHandshake(ctx context.Context, s string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
if m.tlsPrefer {
return m.credentials.ClientHandshake(ctx, s, conn)
}
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}
func (m *muxTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
var prefix = make([]byte, tlsRecordPrefix)
n, err := conn.Read(prefix)
if err != nil {
return nil, nil, err
}
if n != tlsRecordPrefix {
_ = conn.Close()
return nil, nil, fmt.Errorf("short read handshake")
}
conn = &muxConn{
Conn: conn,
buf: prefix,
}
// tls
if m.tlsMatcher(bytes.NewReader(prefix)) {
return m.credentials.ServerHandshake(conn)
}
// non-tls
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
}
func (m *muxTransportCredentials) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{
ProtocolVersion: "",
SecurityProtocol: "mux",
ServerName: "",
}
}
func (m *muxTransportCredentials) Clone() credentials.TransportCredentials {
return &muxTransportCredentials{
tlsMatcher: cmux.TLS(),
credentials: m.credentials.Clone(),
}
}
func (m *muxTransportCredentials) OverrideServerName(s string) error {
return m.credentials.OverrideServerName(s)
}
// info contains the auth information for an insecure connection.
// It implements the AuthInfo interface.
type info struct {
credentials.CommonAuthInfo
}
// AuthType returns the type of info as a string.
func (info) AuthType() string {
return "insecure"
}
type muxConn struct {
net.Conn
buf []byte
}
func (m *muxConn) Read(b []byte) (int, error) {
if len(m.buf) == 0 {
return m.Conn.Read(b)
}
wn := copy(b, m.buf)
if wn < len(m.buf) {
m.buf = m.buf[wn:]
return wn, nil
}
m.buf = nil
b = b[wn:]
n, err := m.Conn.Read(b)
return n + wn, err
}

69
pkg/rpc/mux_test.go Normal file
View File

@ -0,0 +1,69 @@
/*
* Copyright 2022 The Dragonfly Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package rpc
import (
"bytes"
"io/ioutil"
"net"
"testing"
testifyassert "github.com/stretchr/testify/assert"
)
type testConn struct {
net.Conn
buf *bytes.Buffer
}
func (conn *testConn) Read(b []byte) (int, error) {
return conn.buf.Read(b)
}
func Test_muxConn(t *testing.T) {
assert := testifyassert.New(t)
testCases := []struct {
name string
bufSize int
data []byte
}{
{
name: "buf size equal data size",
bufSize: 4,
data: []byte("hell"),
},
{
name: "buf size less than data size",
bufSize: 4,
data: []byte("hello world"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mc := &muxConn{
buf: tc.data[:tc.bufSize],
Conn: &testConn{Conn: nil, buf: bytes.NewBuffer(tc.data[tc.bufSize:])},
}
data, err := ioutil.ReadAll(mc)
assert.Nil(err, "read all should ok")
assert.Equal(tc.data, data, "data shloud be same")
})
}
}