diff --git a/manager/database/database.go b/manager/database/database.go index dc03d812f..e978c7774 100644 --- a/manager/database/database.go +++ b/manager/database/database.go @@ -90,7 +90,6 @@ func New(cfg *config.Config) (*Database, error) { SentinelPassword: cfg.Database.Redis.SentinelPassword, }) if err != nil { - logger.Errorf("redis: %s", err.Error()) return nil, err } diff --git a/pkg/redis/proxy.go b/pkg/redis/proxy.go index 8acc3565e..4fbedf6e7 100644 --- a/pkg/redis/proxy.go +++ b/pkg/redis/proxy.go @@ -17,6 +17,7 @@ package redis import ( + "bufio" "io" "net" "sync" @@ -79,6 +80,13 @@ func (p *proxy) Stop() { // handleConn handles the incoming connection and establishes a connection to the remote host. func (p *proxy) handleConn(conn net.Conn) { defer conn.Close() + + reader, isRedisProtocol := p.isRedisProtocol(conn) + if !isRedisProtocol { + logger.Errorf("not a redis protocol: %s", conn.RemoteAddr()) + return + } + rConn, err := net.Dial("tcp", p.to) if err != nil { logger.Errorf("error dialing remote host: %v", err) @@ -89,7 +97,7 @@ func (p *proxy) handleConn(conn net.Conn) { wg := &sync.WaitGroup{} wg.Add(2) go p.copy(rConn, conn, wg) - go p.copy(conn, rConn, wg) + go p.copyReader(reader, rConn, wg) wg.Wait() } @@ -101,9 +109,44 @@ func (p *proxy) copy(from, to net.Conn, wg *sync.WaitGroup) { return default: if _, err := io.Copy(to, from); err != nil { - logger.Errorf("error copy: %v", err) + logger.Errorf("error copying from %s to %s: %v", from.RemoteAddr(), to.RemoteAddr(), err) p.Stop() return } } } + +// copyReader copies data from a reader to a connection. +func (p *proxy) copyReader(from io.Reader, to net.Conn, wg *sync.WaitGroup) { + defer wg.Done() + select { + case <-p.done: + return + default: + if _, err := io.Copy(to, from); err != nil { + logger.Errorf("error copying to %s: %v", to.RemoteAddr(), err) + p.Stop() + return + } + } +} + +// isRedisProtocol checks if the connection uses the Redis protocol. +func (p *proxy) isRedisProtocol(conn net.Conn) (io.Reader, bool) { + reader := bufio.NewReader(conn) + firstByte, err := reader.Peek(1) + if err != nil { + if err != io.EOF { + logger.Errorf("reading first byte from client failed: %s: %v", conn.RemoteAddr(), err) + } + + return reader, false + } + + switch firstByte[0] { + case '*', '+', '-', ':', '$': + return reader, true + } + + return reader, false +} diff --git a/pkg/redis/redis.go b/pkg/redis/redis.go index eebb82037..73ac37f3a 100644 --- a/pkg/redis/redis.go +++ b/pkg/redis/redis.go @@ -22,6 +22,7 @@ import ( "github.com/redis/go-redis/v9" + logger "d7y.io/dragonfly/v2/internal/dflog" "d7y.io/dragonfly/v2/pkg/types" ) @@ -76,6 +77,7 @@ func NewRedis(cfg *redis.UniversalOptions) (redis.UniversalClient, error) { }) if err := client.Ping(context.Background()).Err(); err != nil { + logger.Errorf("failed to ping redis: %v", err) return nil, err }