refactor(manager): redis proxy to use bufio.Reader for protocol detection (#4062)

Signed-off-by: Gaius <gaius.qi@gmail.com>
This commit is contained in:
Gaius 2025-05-13 17:28:31 +08:00 committed by GitHub
parent b94a1c4edf
commit b2664ab56f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 3 deletions

View File

@ -90,7 +90,6 @@ func New(cfg *config.Config) (*Database, error) {
SentinelPassword: cfg.Database.Redis.SentinelPassword, SentinelPassword: cfg.Database.Redis.SentinelPassword,
}) })
if err != nil { if err != nil {
logger.Errorf("redis: %s", err.Error())
return nil, err return nil, err
} }

View File

@ -17,6 +17,7 @@
package redis package redis
import ( import (
"bufio"
"io" "io"
"net" "net"
"sync" "sync"
@ -79,6 +80,13 @@ func (p *proxy) Stop() {
// handleConn handles the incoming connection and establishes a connection to the remote host. // handleConn handles the incoming connection and establishes a connection to the remote host.
func (p *proxy) handleConn(conn net.Conn) { func (p *proxy) handleConn(conn net.Conn) {
defer conn.Close() defer conn.Close()
reader, isRedisProtocol := p.isRedisProtocol(conn)
if !isRedisProtocol {
logger.Errorf("not a redis protocol: %s", conn.RemoteAddr())
return
}
rConn, err := net.Dial("tcp", p.to) rConn, err := net.Dial("tcp", p.to)
if err != nil { if err != nil {
logger.Errorf("error dialing remote host: %v", err) logger.Errorf("error dialing remote host: %v", err)
@ -89,7 +97,7 @@ func (p *proxy) handleConn(conn net.Conn) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(2)
go p.copy(rConn, conn, wg) go p.copy(rConn, conn, wg)
go p.copy(conn, rConn, wg) go p.copyReader(reader, rConn, wg)
wg.Wait() wg.Wait()
} }
@ -101,9 +109,44 @@ func (p *proxy) copy(from, to net.Conn, wg *sync.WaitGroup) {
return return
default: default:
if _, err := io.Copy(to, from); err != nil { if _, err := io.Copy(to, from); err != nil {
logger.Errorf("error copy: %v", err) logger.Errorf("error copying from %s to %s: %v", from.RemoteAddr(), to.RemoteAddr(), err)
p.Stop() p.Stop()
return return
} }
} }
} }
// copyReader copies data from a reader to a connection.
func (p *proxy) copyReader(from io.Reader, to net.Conn, wg *sync.WaitGroup) {
defer wg.Done()
select {
case <-p.done:
return
default:
if _, err := io.Copy(to, from); err != nil {
logger.Errorf("error copying to %s: %v", to.RemoteAddr(), err)
p.Stop()
return
}
}
}
// isRedisProtocol checks if the connection uses the Redis protocol.
func (p *proxy) isRedisProtocol(conn net.Conn) (io.Reader, bool) {
reader := bufio.NewReader(conn)
firstByte, err := reader.Peek(1)
if err != nil {
if err != io.EOF {
logger.Errorf("reading first byte from client failed: %s: %v", conn.RemoteAddr(), err)
}
return reader, false
}
switch firstByte[0] {
case '*', '+', '-', ':', '$':
return reader, true
}
return reader, false
}

View File

@ -22,6 +22,7 @@ import (
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
logger "d7y.io/dragonfly/v2/internal/dflog"
"d7y.io/dragonfly/v2/pkg/types" "d7y.io/dragonfly/v2/pkg/types"
) )
@ -76,6 +77,7 @@ func NewRedis(cfg *redis.UniversalOptions) (redis.UniversalClient, error) {
}) })
if err := client.Ping(context.Background()).Err(); err != nil { if err := client.Ping(context.Background()).Err(); err != nil {
logger.Errorf("failed to ping redis: %v", err)
return nil, err return nil, err
} }