feat: add grpc mux transport (#1602)
Signed-off-by: Jim Ma <majinjing3@gmail.com>
This commit is contained in:
parent
9ab33635c5
commit
98fb1fc427
|
|
@ -20,12 +20,14 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
|
||||||
"github.com/johanbrandhorst/certify"
|
"github.com/johanbrandhorst/certify"
|
||||||
|
|
||||||
|
"d7y.io/dragonfly/v2/pkg/net/ip"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetCertificate(certifyClient *certify.Certify) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
func GetCertificate(certifyClient *certify.Certify) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
// FIXME peers need pure ip cert, certify checks the ServerName, so workaround here
|
// FIXME peers need pure ip cert, certify checks the ServerName, so workaround here
|
||||||
hello.ServerName = "peer"
|
hello.ServerName = ip.IPv4
|
||||||
return certifyClient.GetCertificate(hello)
|
return certifyClient.GetCertificate(hello)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -174,11 +174,26 @@ func ConvertPattern(p string, defaultPattern commonv1.Pattern) commonv1.Pattern
|
||||||
}
|
}
|
||||||
|
|
||||||
type GlobalSecurityOption struct {
|
type GlobalSecurityOption struct {
|
||||||
AutoIssueCert bool `mapstructure:"autoIssueCert" yaml:"autoIssueCert"`
|
// AutoIssueCert indicates to issue client certificates for all grpc call
|
||||||
CACert serialize.PEMContent `mapstructure:"caCert" yaml:"caCert"`
|
// if AutoIssueCert is false, any other option in Security will be ignored
|
||||||
TLSVerify bool `mapstructure:"tlsVerify" yaml:"tlsVerify"`
|
AutoIssueCert bool `mapstructure:"autoIssueCert" yaml:"autoIssueCert"`
|
||||||
|
// CACert is the root CA certificate for all grpc tls handshake, it can be path or PEM format string
|
||||||
|
CACert serialize.PEMContent `mapstructure:"caCert" yaml:"caCert"`
|
||||||
|
// TLSPrefer indicates to verify client certificates for grpc ServerHandshake
|
||||||
|
TLSVerify bool `mapstructure:"tlsVerify" yaml:"tlsVerify"`
|
||||||
|
// TLSPolicy controls the grpc shandshake behaviors:
|
||||||
|
// force: both ClientHandshake and ServerHandshake are only support tls
|
||||||
|
// prefer: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support tls
|
||||||
|
// default or empty: ServerHandshake supports tls and insecure (non-tls), ClientHandshake will only support insecure (non-tls)
|
||||||
|
TLSPolicy string `mapstructure:"tlsPolicy" yaml:"tlsPolicy"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
TLSPolicyForce = "force"
|
||||||
|
TLSPolicyPrefer = "prefer"
|
||||||
|
TLSPolicyDefault = "default"
|
||||||
|
)
|
||||||
|
|
||||||
type SchedulerOption struct {
|
type SchedulerOption struct {
|
||||||
// Manager is to get the scheduler configuration remotely.
|
// Manager is to get the scheduler configuration remotely.
|
||||||
Manager ManagerOption `mapstructure:"manager" yaml:"manager"`
|
Manager ManagerOption `mapstructure:"manager" yaml:"manager"`
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ var peerHostConfig = func() *DaemonOption {
|
||||||
DownloadGRPC: ListenOption{
|
DownloadGRPC: ListenOption{
|
||||||
Security: SecurityOption{
|
Security: SecurityOption{
|
||||||
Insecure: true,
|
Insecure: true,
|
||||||
TLSVerify: true,
|
TLSVerify: false,
|
||||||
},
|
},
|
||||||
UnixListen: &UnixListenOption{},
|
UnixListen: &UnixListenOption{},
|
||||||
},
|
},
|
||||||
|
|
@ -176,5 +176,11 @@ var peerHostConfig = func() *DaemonOption {
|
||||||
Duration: time.Minute,
|
Duration: time.Minute,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Security: GlobalSecurityOption{
|
||||||
|
AutoIssueCert: false,
|
||||||
|
CACert: serialize.PEMContent(""),
|
||||||
|
TLSVerify: false,
|
||||||
|
TLSPolicy: TLSPolicyDefault,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ import (
|
||||||
"d7y.io/dragonfly/v2/pkg/dfnet"
|
"d7y.io/dragonfly/v2/pkg/dfnet"
|
||||||
"d7y.io/dragonfly/v2/pkg/net/fqdn"
|
"d7y.io/dragonfly/v2/pkg/net/fqdn"
|
||||||
"d7y.io/dragonfly/v2/pkg/net/ip"
|
"d7y.io/dragonfly/v2/pkg/net/ip"
|
||||||
|
"d7y.io/dragonfly/v2/pkg/serialize"
|
||||||
)
|
)
|
||||||
|
|
||||||
var peerHostConfig = func() *DaemonOption {
|
var peerHostConfig = func() *DaemonOption {
|
||||||
|
|
@ -81,7 +82,7 @@ var peerHostConfig = func() *DaemonOption {
|
||||||
DownloadGRPC: ListenOption{
|
DownloadGRPC: ListenOption{
|
||||||
Security: SecurityOption{
|
Security: SecurityOption{
|
||||||
Insecure: true,
|
Insecure: true,
|
||||||
TLSVerify: true,
|
TLSVerify: false,
|
||||||
},
|
},
|
||||||
UnixListen: &UnixListenOption{},
|
UnixListen: &UnixListenOption{},
|
||||||
},
|
},
|
||||||
|
|
@ -175,5 +176,11 @@ var peerHostConfig = func() *DaemonOption {
|
||||||
Duration: time.Minute,
|
Duration: time.Minute,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Security: GlobalSecurityOption{
|
||||||
|
AutoIssueCert: false,
|
||||||
|
CACert: serialize.PEMContent(""),
|
||||||
|
TLSVerify: false,
|
||||||
|
TLSPolicy: TLSPolicyDefault,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -145,13 +145,15 @@ func New(opt *config.DaemonOption, d dfpath.Dfpath) (Daemon, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// issue a certificate to reduce first time delay
|
// issue a certificate to reduce first time delay
|
||||||
_, err := certifyClient.GetCertificate(&tls.ClientHelloInfo{
|
cert, err := certifyClient.GetCertificate(&tls.ClientHelloInfo{
|
||||||
ServerName: ip.IPv4,
|
ServerName: ip.IPv4,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("issue certificate error: %s", err.Error())
|
logger.Errorf("issue certificate error: %s", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
logger.Debugf("request cert from manager, common name: %s, issuer: %s",
|
||||||
|
cert.Leaf.Subject.CommonName, cert.Leaf.Issuer.CommonName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New dynconfig manager client.
|
// New dynconfig manager client.
|
||||||
|
|
@ -337,6 +339,7 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
|
||||||
}
|
}
|
||||||
|
|
||||||
opt.TLSConfig.ClientCAs = certPool
|
opt.TLSConfig.ClientCAs = certPool
|
||||||
|
opt.TLSConfig.RootCAs = certPool
|
||||||
|
|
||||||
// Load server's certificate and private key
|
// Load server's certificate and private key
|
||||||
if certifyClient == nil {
|
if certifyClient == nil {
|
||||||
|
|
@ -347,7 +350,6 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
|
||||||
opt.TLSConfig.Certificates = []tls.Certificate{serverCert}
|
opt.TLSConfig.Certificates = []tls.Certificate{serverCert}
|
||||||
} else {
|
} else {
|
||||||
// enable auto issue certificate
|
// enable auto issue certificate
|
||||||
opt.TLSConfig.Certificates = nil
|
|
||||||
opt.TLSConfig.GetCertificate = config.GetCertificate(certifyClient)
|
opt.TLSConfig.GetCertificate = config.GetCertificate(certifyClient)
|
||||||
opt.TLSConfig.GetClientCertificate = certifyClient.GetClientCertificate
|
opt.TLSConfig.GetClientCertificate = certifyClient.GetClientCertificate
|
||||||
}
|
}
|
||||||
|
|
@ -356,7 +358,15 @@ func loadGPRCTLSCredentials(opt config.SecurityOption, certifyClient *certify.Ce
|
||||||
opt.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
opt.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
}
|
}
|
||||||
|
|
||||||
return credentials.NewTLS(opt.TLSConfig), nil
|
switch security.TLSPolicy {
|
||||||
|
case config.TLSPolicyDefault, config.TLSPolicyPrefer:
|
||||||
|
return rpc.NewMuxTransportCredentials(opt.TLSConfig,
|
||||||
|
rpc.WithTLSPreferClientHandshake(security.TLSPolicy == config.TLSPolicyPrefer)), nil
|
||||||
|
case config.TLSPolicyForce:
|
||||||
|
return credentials.NewTLS(opt.TLSConfig), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid tlsPolicy: %s", security.TLSPolicy)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadGlobalGPRCTLSCredentials(certifyClient *certify.Certify, security config.GlobalSecurityOption) (credentials.TransportCredentials, error) {
|
func loadGlobalGPRCTLSCredentials(certifyClient *certify.Certify, security config.GlobalSecurityOption) (credentials.TransportCredentials, error) {
|
||||||
|
|
@ -370,17 +380,26 @@ func loadGlobalGPRCTLSCredentials(certifyClient *certify.Certify, security confi
|
||||||
return nil, fmt.Errorf("failed to add global CA's certificate")
|
return nil, fmt.Errorf("failed to add global CA's certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
ClientCAs: certPool,
|
ClientCAs: certPool,
|
||||||
|
RootCAs: certPool,
|
||||||
GetCertificate: config.GetCertificate(certifyClient),
|
GetCertificate: config.GetCertificate(certifyClient),
|
||||||
GetClientCertificate: certifyClient.GetClientCertificate,
|
GetClientCertificate: certifyClient.GetClientCertificate,
|
||||||
}
|
}
|
||||||
|
|
||||||
if security.TLSVerify {
|
if security.TLSVerify {
|
||||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
}
|
}
|
||||||
|
|
||||||
return credentials.NewTLS(config), nil
|
switch security.TLSPolicy {
|
||||||
|
case config.TLSPolicyDefault, config.TLSPolicyPrefer:
|
||||||
|
return rpc.NewMuxTransportCredentials(tlsConfig,
|
||||||
|
rpc.WithTLSPreferClientHandshake(security.TLSPolicy == config.TLSPolicyPrefer)), nil
|
||||||
|
case config.TLSPolicyForce:
|
||||||
|
return credentials.NewTLS(tlsConfig), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid tlsPolicy: %s", security.TLSPolicy)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*clientDaemon) prepareTCPListener(opt config.ListenOption, withTLS bool) (net.Listener, int, error) {
|
func (*clientDaemon) prepareTCPListener(opt config.ListenOption, withTLS bool) (net.Listener, int, error) {
|
||||||
|
|
|
||||||
|
|
@ -485,7 +485,7 @@ func (pm *pieceManager) downloadKnownLengthSource(ctx context.Context, pt Task,
|
||||||
|
|
||||||
pt.ReportPieceResult(request, result, nil)
|
pt.ReportPieceResult(request, result, nil)
|
||||||
pt.PublishPieceInfo(pieceNum, uint32(result.Size))
|
pt.PublishPieceInfo(pieceNum, uint32(result.Size))
|
||||||
if supportConcurrent {
|
if supportConcurrent && pieceNum+2 < maxPieceNum {
|
||||||
// the time unit of FinishTime and BeginTime is ns
|
// the time unit of FinishTime and BeginTime is ns
|
||||||
speed := float64(pieceSize) / float64((result.FinishTime-result.BeginTime)/1000000)
|
speed := float64(pieceSize) / float64((result.FinishTime-result.BeginTime)/1000000)
|
||||||
if speed < float64(pm.concurrentOption.ThresholdSpeed) {
|
if speed < float64(pm.concurrentOption.ThresholdSpeed) {
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,8 @@ import (
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
grpcpeer "google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/emptypb"
|
"google.golang.org/protobuf/types/known/emptypb"
|
||||||
|
|
||||||
|
|
@ -211,7 +213,28 @@ func (s *server) sendFirstPieceTasks(
|
||||||
return sendExistPieces(sync.Context(), log, get, request, sync, sentMap, false)
|
return sendExistPieces(sync.Context(), log, get, request, sync, sentMap, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func printAuthInfo(ctx context.Context) {
|
||||||
|
if peerInfo, ok := grpcpeer.FromContext(ctx); ok {
|
||||||
|
if tlsInfo, ok := peerInfo.AuthInfo.(credentials.TLSInfo); ok {
|
||||||
|
for i, pc := range tlsInfo.State.PeerCertificates {
|
||||||
|
logger.Debugf("peer cert depth %d, issuer: %#v", i, pc.Issuer.CommonName)
|
||||||
|
logger.Debugf("peer cert depth %d, common name: %#v", i, pc.Subject.CommonName)
|
||||||
|
if len(pc.IPAddresses) > 0 {
|
||||||
|
logger.Debugf("peer cert depth %d, ip: %#v", i, pc.IPAddresses)
|
||||||
|
}
|
||||||
|
if len(pc.DNSNames) > 0 {
|
||||||
|
logger.Debugf("peer cert depth %d, dns: %#v", i, pc.DNSNames)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *server) SyncPieceTasks(sync dfdaemonv1.Daemon_SyncPieceTasksServer) error {
|
func (s *server) SyncPieceTasks(sync dfdaemonv1.Daemon_SyncPieceTasksServer) error {
|
||||||
|
if logger.IsDebug() {
|
||||||
|
printAuthInfo(sync.Context())
|
||||||
|
}
|
||||||
|
|
||||||
request, err := sync.Recv()
|
request, err := sync.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("receive first sync piece tasks request error: %s", err.Error())
|
logger.Errorf("receive first sync piece tasks request error: %s", err.Error())
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,10 @@ func (s *seeder) SyncPieceTasks(tasksServer cdnsystemv1.Seeder_SyncPieceTasksSer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *seeder) ObtainSeeds(seedRequest *cdnsystemv1.SeedRequest, seedsServer cdnsystemv1.Seeder_ObtainSeedsServer) error {
|
func (s *seeder) ObtainSeeds(seedRequest *cdnsystemv1.SeedRequest, seedsServer cdnsystemv1.Seeder_ObtainSeedsServer) error {
|
||||||
|
if logger.IsDebug() {
|
||||||
|
printAuthInfo(seedsServer.Context())
|
||||||
|
}
|
||||||
|
|
||||||
metrics.SeedPeerConcurrentDownloadGauge.Inc()
|
metrics.SeedPeerConcurrentDownloadGauge.Inc()
|
||||||
defer metrics.SeedPeerConcurrentDownloadGauge.Dec()
|
defer metrics.SeedPeerConcurrentDownloadGauge.Dec()
|
||||||
metrics.SeedPeerDownloadCount.Add(1)
|
metrics.SeedPeerDownloadCount.Add(1)
|
||||||
|
|
|
||||||
|
|
@ -215,6 +215,10 @@ func (log *SugaredLoggerOnWith) Debug(args ...any) {
|
||||||
CoreLogger.Debugw(fmt.Sprint(args...), log.withArgs...)
|
CoreLogger.Debugw(fmt.Sprint(args...), log.withArgs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (log *SugaredLoggerOnWith) IsDebug() bool {
|
||||||
|
return coreLogLevelEnabler.Enabled(zap.DebugLevel)
|
||||||
|
}
|
||||||
|
|
||||||
func Infof(template string, args ...any) {
|
func Infof(template string, args ...any) {
|
||||||
CoreLogger.Infof(template, args...)
|
CoreLogger.Infof(template, args...)
|
||||||
}
|
}
|
||||||
|
|
@ -243,6 +247,14 @@ func Debugf(template string, args ...any) {
|
||||||
CoreLogger.Debugf(template, args...)
|
CoreLogger.Debugf(template, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Debug(args ...any) {
|
||||||
|
CoreLogger.Debug(args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsDebug() bool {
|
||||||
|
return coreLogLevelEnabler.Enabled(zap.DebugLevel)
|
||||||
|
}
|
||||||
|
|
||||||
func Fatalf(template string, args ...any) {
|
func Fatalf(template string, args ...any) {
|
||||||
CoreLogger.Fatalf(template, args...)
|
CoreLogger.Fatalf(template, args...)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,145 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2022 The Dragonfly Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package rpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/soheilhy/cmux"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// cmux's TLS matcher read at least 3 + 1 bytes
|
||||||
|
tlsRecordPrefix = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
type muxTransportCredentials struct {
|
||||||
|
credentials credentials.TransportCredentials
|
||||||
|
tlsMatcher cmux.Matcher
|
||||||
|
tlsPrefer bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTLSPreferClientHandshake(prefer bool) func(m *muxTransportCredentials) {
|
||||||
|
return func(m *muxTransportCredentials) {
|
||||||
|
m.tlsPrefer = prefer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMuxTransportCredentials(c *tls.Config, opts ...func(m *muxTransportCredentials)) credentials.TransportCredentials {
|
||||||
|
m := &muxTransportCredentials{
|
||||||
|
tlsMatcher: cmux.TLS(),
|
||||||
|
credentials: credentials.NewTLS(c),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *muxTransportCredentials) ClientHandshake(ctx context.Context, s string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
if m.tlsPrefer {
|
||||||
|
return m.credentials.ClientHandshake(ctx, s, conn)
|
||||||
|
}
|
||||||
|
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *muxTransportCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
var prefix = make([]byte, tlsRecordPrefix)
|
||||||
|
|
||||||
|
n, err := conn.Read(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != tlsRecordPrefix {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, nil, fmt.Errorf("short read handshake")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn = &muxConn{
|
||||||
|
Conn: conn,
|
||||||
|
buf: prefix,
|
||||||
|
}
|
||||||
|
|
||||||
|
// tls
|
||||||
|
if m.tlsMatcher(bytes.NewReader(prefix)) {
|
||||||
|
return m.credentials.ServerHandshake(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// non-tls
|
||||||
|
return conn, info{credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *muxTransportCredentials) Info() credentials.ProtocolInfo {
|
||||||
|
return credentials.ProtocolInfo{
|
||||||
|
ProtocolVersion: "",
|
||||||
|
SecurityProtocol: "mux",
|
||||||
|
ServerName: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *muxTransportCredentials) Clone() credentials.TransportCredentials {
|
||||||
|
return &muxTransportCredentials{
|
||||||
|
tlsMatcher: cmux.TLS(),
|
||||||
|
credentials: m.credentials.Clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *muxTransportCredentials) OverrideServerName(s string) error {
|
||||||
|
return m.credentials.OverrideServerName(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// info contains the auth information for an insecure connection.
|
||||||
|
// It implements the AuthInfo interface.
|
||||||
|
type info struct {
|
||||||
|
credentials.CommonAuthInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthType returns the type of info as a string.
|
||||||
|
func (info) AuthType() string {
|
||||||
|
return "insecure"
|
||||||
|
}
|
||||||
|
|
||||||
|
type muxConn struct {
|
||||||
|
net.Conn
|
||||||
|
buf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *muxConn) Read(b []byte) (int, error) {
|
||||||
|
if len(m.buf) == 0 {
|
||||||
|
return m.Conn.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
wn := copy(b, m.buf)
|
||||||
|
if wn < len(m.buf) {
|
||||||
|
m.buf = m.buf[wn:]
|
||||||
|
return wn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.buf = nil
|
||||||
|
b = b[wn:]
|
||||||
|
|
||||||
|
n, err := m.Conn.Read(b)
|
||||||
|
return n + wn, err
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2022 The Dragonfly Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package rpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
testifyassert "github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testConn struct {
|
||||||
|
net.Conn
|
||||||
|
buf *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *testConn) Read(b []byte) (int, error) {
|
||||||
|
return conn.buf.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_muxConn(t *testing.T) {
|
||||||
|
assert := testifyassert.New(t)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
bufSize int
|
||||||
|
data []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "buf size equal data size",
|
||||||
|
bufSize: 4,
|
||||||
|
data: []byte("hell"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "buf size less than data size",
|
||||||
|
bufSize: 4,
|
||||||
|
data: []byte("hello world"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mc := &muxConn{
|
||||||
|
buf: tc.data[:tc.bufSize],
|
||||||
|
Conn: &testConn{Conn: nil, buf: bytes.NewBuffer(tc.data[tc.bufSize:])},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := ioutil.ReadAll(mc)
|
||||||
|
assert.Nil(err, "read all should ok")
|
||||||
|
assert.Equal(tc.data, data, "data shloud be same")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue