chore: avoid use default mux (#3346)

Signed-off-by: Jim Ma <majinjing3@gmail.com>
This commit is contained in:
Jim Ma 2024-07-03 19:50:50 +08:00 committed by GitHub
parent fae3613ff2
commit d66cf41c38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 79 deletions

View File

@ -21,7 +21,6 @@ import (
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -80,7 +79,7 @@ type Proxy struct {
cacheRWMutex sync.RWMutex cacheRWMutex sync.RWMutex
// directHandler are used to handle non-proxy requests // directHandler are used to handle non-proxy requests
directHandler http.Handler directHandler *http.ServeMux
// transport is used to handle http proxy requests // transport is used to handle http proxy requests
transport http.RoundTripper transport http.RoundTripper
@ -175,25 +174,6 @@ func WithCert(cert *tls.Certificate) Option {
} }
} }
// WithDirectHandler sets the handler for non-proxy requests
func WithDirectHandler(h *http.ServeMux) Option {
return func(p *Proxy) *Proxy {
if p.registry == nil || p.registry.Remote == nil || p.registry.Remote.URL == nil {
logger.Warnf("registry mirror url is empty, registry mirror feature is disabled")
h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, fmt.Sprintf("registry mirror feature is disabled"), http.StatusNotFound)
})
p.directHandler = h
return p
}
// Make sure the root handler of the given server mux is the
// registry mirror reverse proxy
h.HandleFunc("/", p.mirrorRegistry)
p.directHandler = h
return p
}
}
// WithRules sets the proxy rules // WithRules sets the proxy rules
func WithRules(rules []*config.ProxyRule) Option { func WithRules(rules []*config.ProxyRule) Option {
return func(p *Proxy) *Proxy { return func(p *Proxy) *Proxy {
@ -290,9 +270,26 @@ func NewProxy(options ...Option) (*Proxy, error) {
if proxy.transport == nil { if proxy.transport == nil {
proxy.transport = proxy.newTransport(nil) proxy.transport = proxy.newTransport(nil)
} }
// check register mirror config and register handler
proxy.updateMirrorHandler()
return proxy, nil return proxy, nil
} }
func (proxy *Proxy) updateMirrorHandler() {
h := proxy.directHandler
if proxy.registry == nil || proxy.registry.Remote == nil || proxy.registry.Remote.URL == nil {
logger.Warnf("registry mirror url is empty, registry mirror feature is disabled")
h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "registry mirror feature is disabled", http.StatusNotFound)
})
return
}
// Make sure the root handler of the given server mux is the
// registry mirror reverse proxy
h.HandleFunc("/", proxy.mirrorRegistry)
}
func isBasicAuthMatch(basicAuth *config.BasicAuth, user, pass string) bool { func isBasicAuthMatch(basicAuth *config.BasicAuth, user, pass string) bool {
usernameOK := subtle.ConstantTimeCompare([]byte(basicAuth.Username), []byte(user)) == 1 usernameOK := subtle.ConstantTimeCompare([]byte(basicAuth.Username), []byte(user)) == 1
passwordOK := subtle.ConstantTimeCompare([]byte(basicAuth.Password), []byte(pass)) == 1 passwordOK := subtle.ConstantTimeCompare([]byte(basicAuth.Password), []byte(pass)) == 1

View File

@ -22,14 +22,10 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"os"
"reflect"
"github.com/spf13/viper"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
schedulerv1 "d7y.io/api/v2/pkg/apis/scheduler/v1" schedulerv1 "d7y.io/api/v2/pkg/apis/scheduler/v1"
@ -137,7 +133,6 @@ func NewProxyManager(peerHost *schedulerv1.PeerHost, peerTaskManager peer.TaskMa
} }
func (pm *proxyManager) Serve(listener net.Listener) error { func (pm *proxyManager) Serve(listener net.Listener) error {
_ = WithDirectHandler(newDirectHandler())(pm.Proxy)
pm.Server.Handler = pm.Proxy pm.Server.Handler = pm.Proxy
return pm.Server.Serve(listener) return pm.Server.Serve(listener)
} }
@ -179,59 +174,6 @@ func (pm *proxyManager) Watch(opt *config.ProxyOption) {
} }
} }
func newDirectHandler() *http.ServeMux {
s := http.DefaultServeMux
s.HandleFunc("/args", getArgs)
s.HandleFunc("/env", getEnv)
return s
}
// getEnv returns the environments of dfdaemon.
func getEnv(w http.ResponseWriter, r *http.Request) {
logger.Debugf("access: %s", r.URL.String())
if err := json.NewEncoder(w).Encode(ensureStringKey(viper.AllSettings())); err != nil {
logger.Errorf("failed to encode env json: %v", err)
}
}
// ensureStringKey recursively ensures all maps in the given interface are string,
// to make the result marshalable by json. This is meant to be used with viper
// settings, so only maps and slices are handled.
func ensureStringKey(obj any) any {
rt, rv := reflect.TypeOf(obj), reflect.ValueOf(obj)
switch rt.Kind() {
case reflect.Map:
res := make(map[string]any)
for _, k := range rv.MapKeys() {
res[fmt.Sprintf("%v", k.Interface())] = ensureStringKey(rv.MapIndex(k).Interface())
}
return res
case reflect.Slice:
res := make([]any, rv.Len())
for i := 0; i < rv.Len(); i++ {
res[i] = ensureStringKey(rv.Index(i).Interface())
}
return res
}
return obj
}
// getArgs returns all the arguments of command-line except the program name.
func getArgs(w http.ResponseWriter, r *http.Request) {
logger.Debugf("access: %s", r.URL.String())
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "text/plain;charset=utf-8")
for index, value := range os.Args {
if index > 0 {
if _, err := w.Write([]byte(value + " ")); err != nil {
logger.Errorf("failed to respond information: %v", err)
}
}
}
}
func certFromFile(certFile string, keyFile string) (*tls.Certificate, error) { func certFromFile(certFile string, keyFile string) (*tls.Certificate, error) {
// cert.Certificate is a chain of one or more certificates, leaf first. // cert.Certificate is a chain of one or more certificates, leaf first.
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)