refactor(manager): redis proxy to use bufio.Reader for protocol detection (#4062)
Signed-off-by: Gaius <gaius.qi@gmail.com>
This commit is contained in:
parent
b94a1c4edf
commit
b2664ab56f
|
|
@ -90,7 +90,6 @@ func New(cfg *config.Config) (*Database, error) {
|
||||||
SentinelPassword: cfg.Database.Redis.SentinelPassword,
|
SentinelPassword: cfg.Database.Redis.SentinelPassword,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("redis: %s", err.Error())
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
package redis
|
package redis
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
@ -79,6 +80,13 @@ func (p *proxy) Stop() {
|
||||||
// handleConn handles the incoming connection and establishes a connection to the remote host.
|
// handleConn handles the incoming connection and establishes a connection to the remote host.
|
||||||
func (p *proxy) handleConn(conn net.Conn) {
|
func (p *proxy) handleConn(conn net.Conn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
reader, isRedisProtocol := p.isRedisProtocol(conn)
|
||||||
|
if !isRedisProtocol {
|
||||||
|
logger.Errorf("not a redis protocol: %s", conn.RemoteAddr())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
rConn, err := net.Dial("tcp", p.to)
|
rConn, err := net.Dial("tcp", p.to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("error dialing remote host: %v", err)
|
logger.Errorf("error dialing remote host: %v", err)
|
||||||
|
|
@ -89,7 +97,7 @@ func (p *proxy) handleConn(conn net.Conn) {
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go p.copy(rConn, conn, wg)
|
go p.copy(rConn, conn, wg)
|
||||||
go p.copy(conn, rConn, wg)
|
go p.copyReader(reader, rConn, wg)
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -101,9 +109,44 @@ func (p *proxy) copy(from, to net.Conn, wg *sync.WaitGroup) {
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
if _, err := io.Copy(to, from); err != nil {
|
if _, err := io.Copy(to, from); err != nil {
|
||||||
logger.Errorf("error copy: %v", err)
|
logger.Errorf("error copying from %s to %s: %v", from.RemoteAddr(), to.RemoteAddr(), err)
|
||||||
p.Stop()
|
p.Stop()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copyReader copies data from a reader to a connection.
|
||||||
|
func (p *proxy) copyReader(from io.Reader, to net.Conn, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if _, err := io.Copy(to, from); err != nil {
|
||||||
|
logger.Errorf("error copying to %s: %v", to.RemoteAddr(), err)
|
||||||
|
p.Stop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRedisProtocol checks if the connection uses the Redis protocol.
|
||||||
|
func (p *proxy) isRedisProtocol(conn net.Conn) (io.Reader, bool) {
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
firstByte, err := reader.Peek(1)
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
logger.Errorf("reading first byte from client failed: %s: %v", conn.RemoteAddr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return reader, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch firstByte[0] {
|
||||||
|
case '*', '+', '-', ':', '$':
|
||||||
|
return reader, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return reader, false
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import (
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
|
|
||||||
|
logger "d7y.io/dragonfly/v2/internal/dflog"
|
||||||
"d7y.io/dragonfly/v2/pkg/types"
|
"d7y.io/dragonfly/v2/pkg/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -76,6 +77,7 @@ func NewRedis(cfg *redis.UniversalOptions) (redis.UniversalClient, error) {
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := client.Ping(context.Background()).Err(); err != nil {
|
if err := client.Ping(context.Background()).Err(); err != nil {
|
||||||
|
logger.Errorf("failed to ping redis: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue