diff --git a/client/daemon/proxy/proxy.go b/client/daemon/proxy/proxy.go index d401cb5c6..9fc03a348 100644 --- a/client/daemon/proxy/proxy.go +++ b/client/daemon/proxy/proxy.go @@ -21,7 +21,6 @@ import ( "crypto/tls" "encoding/base64" "errors" - "fmt" "io" "net" "net/http" @@ -80,7 +79,7 @@ type Proxy struct { cacheRWMutex sync.RWMutex // directHandler are used to handle non-proxy requests - directHandler http.Handler + directHandler *http.ServeMux // transport is used to handle http proxy requests transport http.RoundTripper @@ -175,25 +174,6 @@ func WithCert(cert *tls.Certificate) Option { } } -// WithDirectHandler sets the handler for non-proxy requests -func WithDirectHandler(h *http.ServeMux) Option { - return func(p *Proxy) *Proxy { - if p.registry == nil || p.registry.Remote == nil || p.registry.Remote.URL == nil { - logger.Warnf("registry mirror url is empty, registry mirror feature is disabled") - h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, fmt.Sprintf("registry mirror feature is disabled"), http.StatusNotFound) - }) - p.directHandler = h - return p - } - // Make sure the root handler of the given server mux is the - // registry mirror reverse proxy - h.HandleFunc("/", p.mirrorRegistry) - p.directHandler = h - return p - } -} - // WithRules sets the proxy rules func WithRules(rules []*config.ProxyRule) Option { return func(p *Proxy) *Proxy { @@ -290,9 +270,26 @@ func NewProxy(options ...Option) (*Proxy, error) { if proxy.transport == nil { proxy.transport = proxy.newTransport(nil) } + + // check register mirror config and register handler + proxy.updateMirrorHandler() return proxy, nil } +func (proxy *Proxy) updateMirrorHandler() { + h := proxy.directHandler + if proxy.registry == nil || proxy.registry.Remote == nil || proxy.registry.Remote.URL == nil { + logger.Warnf("registry mirror url is empty, registry mirror feature is disabled") + h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "registry mirror feature is disabled", http.StatusNotFound) + }) + return + } + // Make sure the root handler of the given server mux is the + // registry mirror reverse proxy + h.HandleFunc("/", proxy.mirrorRegistry) +} + func isBasicAuthMatch(basicAuth *config.BasicAuth, user, pass string) bool { usernameOK := subtle.ConstantTimeCompare([]byte(basicAuth.Username), []byte(user)) == 1 passwordOK := subtle.ConstantTimeCompare([]byte(basicAuth.Password), []byte(pass)) == 1 diff --git a/client/daemon/proxy/proxy_manager.go b/client/daemon/proxy/proxy_manager.go index 5a531d012..3da2b697f 100644 --- a/client/daemon/proxy/proxy_manager.go +++ b/client/daemon/proxy/proxy_manager.go @@ -22,14 +22,10 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/json" "fmt" "net" "net/http" - "os" - "reflect" - "github.com/spf13/viper" "gopkg.in/yaml.v3" schedulerv1 "d7y.io/api/v2/pkg/apis/scheduler/v1" @@ -137,7 +133,6 @@ func NewProxyManager(peerHost *schedulerv1.PeerHost, peerTaskManager peer.TaskMa } func (pm *proxyManager) Serve(listener net.Listener) error { - _ = WithDirectHandler(newDirectHandler())(pm.Proxy) pm.Server.Handler = pm.Proxy return pm.Server.Serve(listener) } @@ -179,59 +174,6 @@ func (pm *proxyManager) Watch(opt *config.ProxyOption) { } } -func newDirectHandler() *http.ServeMux { - s := http.DefaultServeMux - s.HandleFunc("/args", getArgs) - s.HandleFunc("/env", getEnv) - return s -} - -// getEnv returns the environments of dfdaemon. -func getEnv(w http.ResponseWriter, r *http.Request) { - logger.Debugf("access: %s", r.URL.String()) - if err := json.NewEncoder(w).Encode(ensureStringKey(viper.AllSettings())); err != nil { - logger.Errorf("failed to encode env json: %v", err) - } -} - -// ensureStringKey recursively ensures all maps in the given interface are string, -// to make the result marshalable by json. This is meant to be used with viper -// settings, so only maps and slices are handled. -func ensureStringKey(obj any) any { - rt, rv := reflect.TypeOf(obj), reflect.ValueOf(obj) - switch rt.Kind() { - case reflect.Map: - res := make(map[string]any) - for _, k := range rv.MapKeys() { - res[fmt.Sprintf("%v", k.Interface())] = ensureStringKey(rv.MapIndex(k).Interface()) - } - return res - case reflect.Slice: - res := make([]any, rv.Len()) - for i := 0; i < rv.Len(); i++ { - res[i] = ensureStringKey(rv.Index(i).Interface()) - } - return res - } - return obj -} - -// getArgs returns all the arguments of command-line except the program name. -func getArgs(w http.ResponseWriter, r *http.Request) { - logger.Debugf("access: %s", r.URL.String()) - - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "text/plain;charset=utf-8") - for index, value := range os.Args { - if index > 0 { - if _, err := w.Write([]byte(value + " ")); err != nil { - logger.Errorf("failed to respond information: %v", err) - } - } - - } -} - func certFromFile(certFile string, keyFile string) (*tls.Certificate, error) { // cert.Certificate is a chain of one or more certificates, leaf first. cert, err := tls.LoadX509KeyPair(certFile, keyFile)