chore: avoid use default mux (#3346)
Signed-off-by: Jim Ma <majinjing3@gmail.com>
This commit is contained in:
parent
fae3613ff2
commit
d66cf41c38
|
|
@ -21,7 +21,6 @@ import (
|
|||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
|
@ -80,7 +79,7 @@ type Proxy struct {
|
|||
cacheRWMutex sync.RWMutex
|
||||
|
||||
// directHandler are used to handle non-proxy requests
|
||||
directHandler http.Handler
|
||||
directHandler *http.ServeMux
|
||||
|
||||
// transport is used to handle http proxy requests
|
||||
transport http.RoundTripper
|
||||
|
|
@ -175,25 +174,6 @@ func WithCert(cert *tls.Certificate) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithDirectHandler sets the handler for non-proxy requests
|
||||
func WithDirectHandler(h *http.ServeMux) Option {
|
||||
return func(p *Proxy) *Proxy {
|
||||
if p.registry == nil || p.registry.Remote == nil || p.registry.Remote.URL == nil {
|
||||
logger.Warnf("registry mirror url is empty, registry mirror feature is disabled")
|
||||
h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, fmt.Sprintf("registry mirror feature is disabled"), http.StatusNotFound)
|
||||
})
|
||||
p.directHandler = h
|
||||
return p
|
||||
}
|
||||
// Make sure the root handler of the given server mux is the
|
||||
// registry mirror reverse proxy
|
||||
h.HandleFunc("/", p.mirrorRegistry)
|
||||
p.directHandler = h
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
// WithRules sets the proxy rules
|
||||
func WithRules(rules []*config.ProxyRule) Option {
|
||||
return func(p *Proxy) *Proxy {
|
||||
|
|
@ -290,9 +270,26 @@ func NewProxy(options ...Option) (*Proxy, error) {
|
|||
if proxy.transport == nil {
|
||||
proxy.transport = proxy.newTransport(nil)
|
||||
}
|
||||
|
||||
// check register mirror config and register handler
|
||||
proxy.updateMirrorHandler()
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func (proxy *Proxy) updateMirrorHandler() {
|
||||
h := proxy.directHandler
|
||||
if proxy.registry == nil || proxy.registry.Remote == nil || proxy.registry.Remote.URL == nil {
|
||||
logger.Warnf("registry mirror url is empty, registry mirror feature is disabled")
|
||||
h.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "registry mirror feature is disabled", http.StatusNotFound)
|
||||
})
|
||||
return
|
||||
}
|
||||
// Make sure the root handler of the given server mux is the
|
||||
// registry mirror reverse proxy
|
||||
h.HandleFunc("/", proxy.mirrorRegistry)
|
||||
}
|
||||
|
||||
func isBasicAuthMatch(basicAuth *config.BasicAuth, user, pass string) bool {
|
||||
usernameOK := subtle.ConstantTimeCompare([]byte(basicAuth.Username), []byte(user)) == 1
|
||||
passwordOK := subtle.ConstantTimeCompare([]byte(basicAuth.Password), []byte(pass)) == 1
|
||||
|
|
|
|||
|
|
@ -22,14 +22,10 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
schedulerv1 "d7y.io/api/v2/pkg/apis/scheduler/v1"
|
||||
|
|
@ -137,7 +133,6 @@ func NewProxyManager(peerHost *schedulerv1.PeerHost, peerTaskManager peer.TaskMa
|
|||
}
|
||||
|
||||
func (pm *proxyManager) Serve(listener net.Listener) error {
|
||||
_ = WithDirectHandler(newDirectHandler())(pm.Proxy)
|
||||
pm.Server.Handler = pm.Proxy
|
||||
return pm.Server.Serve(listener)
|
||||
}
|
||||
|
|
@ -179,59 +174,6 @@ func (pm *proxyManager) Watch(opt *config.ProxyOption) {
|
|||
}
|
||||
}
|
||||
|
||||
func newDirectHandler() *http.ServeMux {
|
||||
s := http.DefaultServeMux
|
||||
s.HandleFunc("/args", getArgs)
|
||||
s.HandleFunc("/env", getEnv)
|
||||
return s
|
||||
}
|
||||
|
||||
// getEnv returns the environments of dfdaemon.
|
||||
func getEnv(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Debugf("access: %s", r.URL.String())
|
||||
if err := json.NewEncoder(w).Encode(ensureStringKey(viper.AllSettings())); err != nil {
|
||||
logger.Errorf("failed to encode env json: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureStringKey recursively ensures all maps in the given interface are string,
|
||||
// to make the result marshalable by json. This is meant to be used with viper
|
||||
// settings, so only maps and slices are handled.
|
||||
func ensureStringKey(obj any) any {
|
||||
rt, rv := reflect.TypeOf(obj), reflect.ValueOf(obj)
|
||||
switch rt.Kind() {
|
||||
case reflect.Map:
|
||||
res := make(map[string]any)
|
||||
for _, k := range rv.MapKeys() {
|
||||
res[fmt.Sprintf("%v", k.Interface())] = ensureStringKey(rv.MapIndex(k).Interface())
|
||||
}
|
||||
return res
|
||||
case reflect.Slice:
|
||||
res := make([]any, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
res[i] = ensureStringKey(rv.Index(i).Interface())
|
||||
}
|
||||
return res
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// getArgs returns all the arguments of command-line except the program name.
|
||||
func getArgs(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Debugf("access: %s", r.URL.String())
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "text/plain;charset=utf-8")
|
||||
for index, value := range os.Args {
|
||||
if index > 0 {
|
||||
if _, err := w.Write([]byte(value + " ")); err != nil {
|
||||
logger.Errorf("failed to respond information: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func certFromFile(certFile string, keyFile string) (*tls.Certificate, error) {
|
||||
// cert.Certificate is a chain of one or more certificates, leaf first.
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
|
|
|
|||
Loading…
Reference in New Issue