feat: add grpc mux transport (#1602)
Signed-off-by: Jim Ma <majinjing3@gmail.com>
This commit is contained in:
parent
9ab33635c5
commit
98fb1fc427
|
|
@ -20,12 +20,14 @@ import (
|
|||
"crypto/tls"
|
||||
|
||||
"github.com/johanbrandhorst/certify"
|
||||
|
||||
"d7y.io/dragonfly/v2/pkg/net/ip"
|
||||
)
|
||||
|
||||
func GetCertificate(certifyClient *certify.Certify) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// FIXME peers need pure ip cert, certify checks the ServerName, so workaround here
|
||||
hello.ServerName = "peer"
|
||||
hello.ServerName = ip.IPv4
|
||||
return certifyClient.GetCertificate(hello)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -174,11 +174,26 @@ func ConvertPattern(p string, defaultPattern commonv1.Pattern) commonv1.Pattern
|
|||
}
|
||||
|
||||
type GlobalSecurityOption struct {
|
||||
// AutoIssueCert indicates to issue client certificates for all grpc call
|
||||
// if AutoIssueCert is false, any other option in Security will be ignored
|
||||
AutoIssueCert bool `mapstructure:"autoIssueCert" yaml:"autoIssueCert"`
|
||||
// CACert is the root CA certificate for all grpc tls handshake, it can be path or PEM format string
|
||||
CACert serialize.PEMContent `mapstructure:"caCert" yaml:"caCert"`
|
||||
// TLSPrefer indicates to verify client certificates for grpc ServerHandshake
|
||||
TLSVerify bool `mapstructure:"tlsVerify" yaml:"tlsVerify"`
|
||||
// TLSPolicy controls the grpc shandshake behaviors:
|
||||
// force: both ClientHandshake and ServerHandshake are only support tls
|
||||
// prefer: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support tls
|
||||
// default or empty: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support insecure (non-tls)
|
||||
TLSPolicy string `mapstructure:"tlsPolicy" yaml:"tlsPolicy"`
|
||||
}
|
||||
|
||||
const (
|
||||
TLSPolicyForce = "force"
|
||||
TLSPolicyPrefer = "prefer"
|
||||
TLSPolicyDefault = "default"
|
||||
)
|
||||
|
||||
type SchedulerOption struct {
|
||||
// Manager is to get the scheduler configuration remotely.
|
||||
Manager ManagerOption `mapstructure:"manager" yaml:"manager"`
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ var peerHostConfig = func() *DaemonOption {
|
|||
DownloadGRPC: ListenOption{
|
||||
Security: SecurityOption{
|
||||
Insecure: true,
|
||||
TLSVerify: true,
|
||||
TLSVerify: false,
|
||||
},
|
||||
UnixListen: &UnixListenOption{},
|
||||
},
|
||||
|
|
@ -176,5 +176,11 @@ var peerHostConfig = func() *DaemonOption {
|
|||
Duration: time.Minute,
|
||||
},
|
||||
},
|
||||
Security: GlobalSecurityOption{
|
||||
AutoIssueCert: false,
|
||||
CACert: serialize.PEMContent(""),
|
||||
TLSVerify: false,
|
||||
TLSPolicy: TLSPolicyDefault,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ import (
|
|||
"d7y.io/dragonfly/v2/pkg/dfnet"
|
||||
"d7y.io/dragonfly/v2/pkg/net/fqdn"
|
||||
"d7y.io/dragonfly/v2/pkg/net/ip"
|
||||
"d7y.io/dragonfly/v2/pkg/serialize"
|
||||
)
|
||||
|
||||
var peerHostConfig = func() *DaemonOption {
|
||||
|
|
@ -81,7 +82,7 @@ var peerHostConfig = func() *DaemonOption {
|
|||
DownloadGRPC: ListenOption{
|
||||
Security: SecurityOption{
|
||||
Insecure: true,
|
||||
TLSVerify: true,
|
||||
TLSVerify: false,
|
||||
},
|
||||
UnixListen: &UnixListenOption{},
|
||||
},
|
||||
|
|
@ -175,5 +176,11 @@ var peerHostConfig = func() *DaemonOption {
|
|||
Duration: time.Minute,
|
||||
},
|
||||
},
|
||||
Security: GlobalSecurityOption{
|
||||
AutoIssueCert: false,
|
||||
CACert: serialize.PEMContent(""),
|
||||
TLSVerify: false,
|
||||
TLSPolicy: TLSPolicyDefault,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -145,13 +145,15 @@ func New(opt *config.DaemonOption, d dfpath.Dfpath) (Daemon, error) {
|
|||
}
|
||||
|
||||
// issue a certificate to reduce first time delay
|
||||
_, err := certifyClient.GetCertificate(&tls.ClientHelloInfo{
|
||||
cert, err := certifyClient.GetCertificate(&tls.ClientHelloInfo{
|
||||
ServerName: ip.IPv4,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Errorf("issue certificate error: %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
logger.Debugf("request cert from manager, common name: %s, issuer: %s",
|
||||
cert.Leaf.Subject.CommonName, cert.Leaf.Issuer.CommonName)
|
||||
}
|
||||
|
||||
// New dynconfig manager client.
|
||||
|
|
@ -337,6 +339,7 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
|
|||
}
|
||||
|
||||
opt.TLSConfig.ClientCAs = certPool
|
||||
opt.TLSConfig.RootCAs = certPool
|
||||
|
||||
// Load server's certificate and private key
|
||||
if certifyClient == nil {
|
||||
|
|
@ -347,7 +350,6 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
|
|||
opt.TLSConfig.Certificates = []tls.Certificate{serverCert}
|
||||
} else {
|
||||
// enable auto issue certificate
|
||||
opt.TLSConfig.Certificates = nil
|
||||
opt.TLSConfig.GetCertificate = config.GetCertificate(certifyClient)
|
||||
opt.TLSConfig.GetClientCertificate = certifyClient.GetClientCertificate
|
||||
}
|
||||
|
|
@ -356,7 +358,15 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
|
|||
opt.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
switch security.TLSPolicy {
|
||||
case config.TLSPolicyDefault, config.TLSPolicyPrefer:
|
||||
return rpc.NewMuxTransportCredentials(opt.TLSConfig,
|
||||
rpc.WithTLSPreferClientHandshake(security.TLSPolicy == config.TLSPolicyPrefer)), nil
|
||||
case config.TLSPolicyForce:
|
||||
return credentials.NewTLS(opt.TLSConfig), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid tlsPolicy: %s", security.TLSPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
func loadGlobalGPRCTLSCredentials(certifyClient *certify.Certify, security config.GlobalSecurityOption) (credentials.TransportCredentials, error) {
|
||||
|
|
@ -370,17 +380,26 @@ func loadGlobalGPRCTLSCredentials(certifyClient *certify.Certify, security confi
|
|||
return nil, fmt.Errorf("failed to add global CA's certificate")
|
||||
}
|
||||
|
||||
config := &tls.Config{
|
||||
tlsConfig := &tls.Config{
|
||||
ClientCAs: certPool,
|
||||
RootCAs: certPool,
|
||||
GetCertificate: config.GetCertificate(certifyClient),
|
||||
GetClientCertificate: certifyClient.GetClientCertificate,
|
||||
}
|
||||
|
||||
if security.TLSVerify {
|
||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
return credentials.NewTLS(config), nil
|
||||
switch security.TLSPolicy {
|
||||
case config.TLSPolicyDefault, config.TLSPolicyPrefer:
|
||||
return rpc.NewMuxTransportCredentials(tlsConfig,
|
||||
rpc.WithTLSPreferClientHandshake(security.TLSPolicy == config.TLSPolicyPrefer)), nil
|
||||
case config.TLSPolicyForce:
|
||||
return credentials.NewTLS(tlsConfig), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid tlsPolicy: %s", security.TLSPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
func (*clientDaemon) prepareTCPListener(opt config.ListenOption, withTLS bool) (net.Listener, int, error) {
|
||||
|
|
|
|||
|
|
@ -485,7 +485,7 @@ func (pm *pieceManager) downloadKnownLengthSource(ctx context.Context, pt Task,
|
|||
|
||||
pt.ReportPieceResult(request, result, nil)
|
||||
pt.PublishPieceInfo(pieceNum, uint32(result.Size))
|
||||
if supportConcurrent {
|
||||
if supportConcurrent && pieceNum+2 < maxPieceNum {
|
||||
// the time unit of FinishTime and BeginTime is ns
|
||||
speed := float64(pieceSize) / float64((result.FinishTime-result.BeginTime)/1000000)
|
||||
if speed < float64(pm.concurrentOption.ThresholdSpeed) {
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ import (
|
|||
"go.uber.org/atomic"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
grpcpeer "google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
|
|
@ -211,7 +213,28 @@ func (s *server) sendFirstPieceTasks(
|
|||
return sendExistPieces(sync.Context(), log, get, request, sync, sentMap, false)
|
||||
}
|
||||
|
||||
func printAuthInfo(ctx context.Context) {
|
||||
if peerInfo, ok := grpcpeer.FromContext(ctx); ok {
|
||||
if tlsInfo, ok := peerInfo.AuthInfo.(credentials.TLSInfo); ok {
|
||||
for i, pc := range tlsInfo.State.PeerCertificates {
|
||||
logger.Debugf("peer cert depth %d, issuer: %#v", i, pc.Issuer.CommonName)
|
||||
logger.Debugf("peer cert depth %d, common name: %#v", i, pc.Subject.CommonName)
|
||||
if len(pc.IPAddresses) > 0 {
|
||||
logger.Debugf("peer cert depth %d, ip: %#v", i, pc.IPAddresses)
|
||||
}
|
||||
if len(pc.DNSNames) > 0 {
|
||||
logger.Debugf("peer cert depth %d, dns: %#v", i, pc.DNSNames)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) SyncPieceTasks(sync dfdaemonv1.Daemon_SyncPieceTasksServer) error {
|
||||
if logger.IsDebug() {
|
||||
printAuthInfo(sync.Context())
|
||||
}
|
||||
|
||||
request, err := sync.Recv()
|
||||
if err != nil {
|
||||
logger.Errorf("receive first sync piece tasks request error: %s", err.Error())
|
||||
|
|
|
|||
|
|
@ -52,6 +52,10 @@ func (s *seeder) SyncPieceTasks(tasksServer cdnsystemv1.Seeder_SyncPieceTasksSer
|
|||
}
|
||||
|
||||
func (s *seeder) ObtainSeeds(seedRequest *cdnsystemv1.SeedRequest, seedsServer cdnsystemv1.Seeder_ObtainSeedsServer) error {
|
||||
if logger.IsDebug() {
|
||||
printAuthInfo(seedsServer.Context())
|
||||
}
|
||||
|
||||
metrics.SeedPeerConcurrentDownloadGauge.Inc()
|
||||
defer metrics.SeedPeerConcurrentDownloadGauge.Dec()
|
||||
metrics.SeedPeerDownloadCount.Add(1)
|
||||
|
|
|
|||
|
|
@ -215,6 +215,10 @@ func (log *SugaredLoggerOnWith) Debug(args ...any) {
|
|||
CoreLogger.Debugw(fmt.Sprint(args...), log.withArgs...)
|
||||
}
|
||||
|
||||
func (log *SugaredLoggerOnWith) IsDebug() bool {
|
||||
return coreLogLevelEnabler.Enabled(zap.DebugLevel)
|
||||
}
|
||||
|
||||
func Infof(template string, args ...any) {
|
||||
CoreLogger.Infof(template, args...)
|
||||
}
|
||||
|
|
@ -243,6 +247,14 @@ func Debugf(template string, args ...any) {
|
|||
CoreLogger.Debugf(template, args...)
|
||||
}
|
||||
|
||||
func Debug(args ...any) {
|
||||
CoreLogger.Debug(args...)
|
||||
}
|
||||
|
||||
func IsDebug() bool {
|
||||
return coreLogLevelEnabler.Enabled(zap.DebugLevel)
|
||||
}
|
||||
|
||||
func Fatalf(template string, args ...any) {
|
||||
CoreLogger.Fatalf(template, args...)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,145 @@
|
|||
/*
|
||||
* Copyright 2022 The Dragonfly Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/soheilhy/cmux"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
// cmux's TLS matcher read at least 3 + 1 bytes
|
||||
tlsRecordPrefix = 4
|
||||
)
|
||||
|
||||
type muxTransportCredentials struct {
|
||||
credentials credentials.TransportCredentials
|
||||
tlsMatcher cmux.Matcher
|
||||
tlsPrefer bool
|
||||
}
|
||||
|
||||
func WithTLSPreferClientHandshake(prefer bool) func(m *muxTransportCredentials) {
|
||||
return func(m *muxTransportCredentials) {
|
||||
m.tlsPrefer = prefer
|
||||
}
|
||||
}
|
||||
|
||||
func NewMuxTransportCredentials(c *tls.Config, opts ...func(m *muxTransportCredentials)) credentials.TransportCredentials {
|
||||
m := &muxTransportCredentials{
|
||||
tlsMatcher: cmux.TLS(),
|
||||
credentials: credentials.NewTLS(c),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *muxTransportCredentials) ClientHandshake(ctx context.Context, s string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
if m.tlsPrefer {
|
||||
return m.credentials.ClientHandshake(ctx, s, conn)
|
||||
}
|
||||
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
|
||||
}
|
||||
|
||||
func (m *muxTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
var prefix = make([]byte, tlsRecordPrefix)
|
||||
|
||||
n, err := conn.Read(prefix)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if n != tlsRecordPrefix {
|
||||
_ = conn.Close()
|
||||
return nil, nil, fmt.Errorf("short read handshake")
|
||||
}
|
||||
|
||||
conn = &muxConn{
|
||||
Conn: conn,
|
||||
buf: prefix,
|
||||
}
|
||||
|
||||
// tls
|
||||
if m.tlsMatcher(bytes.NewReader(prefix)) {
|
||||
return m.credentials.ServerHandshake(conn)
|
||||
}
|
||||
|
||||
// non-tls
|
||||
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
|
||||
}
|
||||
|
||||
func (m *muxTransportCredentials) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{
|
||||
ProtocolVersion: "",
|
||||
SecurityProtocol: "mux",
|
||||
ServerName: "",
|
||||
}
|
||||
}
|
||||
|
||||
func (m *muxTransportCredentials) Clone() credentials.TransportCredentials {
|
||||
return &muxTransportCredentials{
|
||||
tlsMatcher: cmux.TLS(),
|
||||
credentials: m.credentials.Clone(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *muxTransportCredentials) OverrideServerName(s string) error {
|
||||
return m.credentials.OverrideServerName(s)
|
||||
}
|
||||
|
||||
// info contains the auth information for an insecure connection.
|
||||
// It implements the AuthInfo interface.
|
||||
type info struct {
|
||||
credentials.CommonAuthInfo
|
||||
}
|
||||
|
||||
// AuthType returns the type of info as a string.
|
||||
func (info) AuthType() string {
|
||||
return "insecure"
|
||||
}
|
||||
|
||||
type muxConn struct {
|
||||
net.Conn
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (m *muxConn) Read(b []byte) (int, error) {
|
||||
if len(m.buf) == 0 {
|
||||
return m.Conn.Read(b)
|
||||
}
|
||||
|
||||
wn := copy(b, m.buf)
|
||||
if wn < len(m.buf) {
|
||||
m.buf = m.buf[wn:]
|
||||
return wn, nil
|
||||
}
|
||||
|
||||
m.buf = nil
|
||||
b = b[wn:]
|
||||
|
||||
n, err := m.Conn.Read(b)
|
||||
return n + wn, err
|
||||
}
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
* Copyright 2022 The Dragonfly Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
testifyassert "github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testConn struct {
|
||||
net.Conn
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func (conn *testConn) Read(b []byte) (int, error) {
|
||||
return conn.buf.Read(b)
|
||||
}
|
||||
|
||||
func Test_muxConn(t *testing.T) {
|
||||
assert := testifyassert.New(t)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
bufSize int
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "buf size equal data size",
|
||||
bufSize: 4,
|
||||
data: []byte("hell"),
|
||||
},
|
||||
{
|
||||
name: "buf size less than data size",
|
||||
bufSize: 4,
|
||||
data: []byte("hello world"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mc := &muxConn{
|
||||
buf: tc.data[:tc.bufSize],
|
||||
Conn: &testConn{Conn: nil, buf: bytes.NewBuffer(tc.data[tc.bufSize:])},
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadAll(mc)
|
||||
assert.Nil(err, "read all should ok")
|
||||
assert.Equal(tc.data, data, "data shloud be same")
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue