diff --git a/pkg/server/mux/pathrecorder.go b/pkg/server/mux/pathrecorder.go index 40a9e75cf..7a343b369 100644 --- a/pkg/server/mux/pathrecorder.go +++ b/pkg/server/mux/pathrecorder.go @@ -21,13 +21,21 @@ import ( "net/http" "runtime/debug" "sort" + "sync" + "sync/atomic" utilruntime "k8s.io/apimachinery/pkg/util/runtime" ) -// PathRecorderMux wraps a mux object and records the registered exposedPaths. It is _not_ go routine safe. +// PathRecorderMux wraps a mux object and records the registered exposedPaths. type PathRecorderMux struct { - mux *http.ServeMux + lock sync.Mutex + pathToHandler map[string]http.Handler + + // mux stores an *http.ServeMux and is used to handle the actual serving + mux atomic.Value + + // exposedPaths is the list of paths that should be shown at / exposedPaths []string // pathStacks holds the stacks of all registered paths. This allows us to show a more helpful message @@ -37,10 +45,15 @@ type PathRecorderMux struct { // NewPathRecorderMux creates a new PathRecorderMux with the given mux as the base mux. func NewPathRecorderMux() *PathRecorderMux { - return &PathRecorderMux{ - mux: http.NewServeMux(), - pathStacks: map[string]string{}, + ret := &PathRecorderMux{ + pathToHandler: map[string]http.Handler{}, + mux: atomic.Value{}, + exposedPaths: []string{}, + pathStacks: map[string]string{}, } + + ret.mux.Store(http.NewServeMux()) + return ret } // ListedPaths returns the registered handler exposedPaths. @@ -58,41 +71,81 @@ func (m *PathRecorderMux) trackCallers(path string) { m.pathStacks[path] = string(debug.Stack()) } +// refreshMuxLocked creates a new mux and must be called while locked. Otherwise the view of handlers may +// not be consistent +func (m *PathRecorderMux) refreshMuxLocked() { + mux := http.NewServeMux() + for path, handler := range m.pathToHandler { + mux.Handle(path, handler) + } + + m.mux.Store(mux) +} + +// Unregister removes a path from the mux. +func (m *PathRecorderMux) Unregister(path string) { + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.pathToHandler, path) + delete(m.pathStacks, path) + for i := range m.exposedPaths { + if m.exposedPaths[i] == path { + m.exposedPaths = append(m.exposedPaths[:i], m.exposedPaths[i+1:]...) + break + } + } + + m.refreshMuxLocked() +} + // Handle registers the handler for the given pattern. // If a handler already exists for pattern, Handle panics. func (m *PathRecorderMux) Handle(path string, handler http.Handler) { + m.lock.Lock() + defer m.lock.Unlock() m.trackCallers(path) m.exposedPaths = append(m.exposedPaths, path) - m.mux.Handle(path, handler) + m.pathToHandler[path] = handler + m.refreshMuxLocked() } // HandleFunc registers the handler function for the given pattern. // If a handler already exists for pattern, Handle panics. func (m *PathRecorderMux) HandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) { + m.lock.Lock() + defer m.lock.Unlock() m.trackCallers(path) m.exposedPaths = append(m.exposedPaths, path) - m.mux.HandleFunc(path, handler) + m.pathToHandler[path] = http.HandlerFunc(handler) + m.refreshMuxLocked() } // UnlistedHandle registers the handler for the given pattern, but doesn't list it. // If a handler already exists for pattern, Handle panics. func (m *PathRecorderMux) UnlistedHandle(path string, handler http.Handler) { + m.lock.Lock() + defer m.lock.Unlock() m.trackCallers(path) - m.mux.Handle(path, handler) + m.pathToHandler[path] = handler + m.refreshMuxLocked() } // UnlistedHandleFunc registers the handler function for the given pattern, but doesn't list it. // If a handler already exists for pattern, Handle panics. func (m *PathRecorderMux) UnlistedHandleFunc(path string, handler func(http.ResponseWriter, *http.Request)) { + m.lock.Lock() + defer m.lock.Unlock() m.trackCallers(path) - m.mux.HandleFunc(path, handler) + m.pathToHandler[path] = http.HandlerFunc(handler) + m.refreshMuxLocked() } // ServeHTTP makes it an http.Handler func (m *PathRecorderMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - m.mux.ServeHTTP(w, r) + m.mux.Load().(*http.ServeMux).ServeHTTP(w, r) } diff --git a/pkg/server/mux/pathrecorder_test.go b/pkg/server/mux/pathrecorder_test.go index 3d7e6b610..9bf64fe78 100644 --- a/pkg/server/mux/pathrecorder_test.go +++ b/pkg/server/mux/pathrecorder_test.go @@ -18,6 +18,7 @@ package mux import ( "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -30,3 +31,39 @@ func TestSecretHandlers(t *testing.T) { assert.NotContains(t, c.ListedPaths(), "/secret") assert.Contains(t, c.ListedPaths(), "/nonswagger") } + +func TestUnregisterHandlers(t *testing.T) { + first := 0 + second := 0 + + c := NewPathRecorderMux() + s := httptest.NewServer(c) + defer s.Close() + + c.UnlistedHandleFunc("/secret", func(http.ResponseWriter, *http.Request) {}) + c.HandleFunc("/nonswagger", func(http.ResponseWriter, *http.Request) { + first = first + 1 + }) + assert.NotContains(t, c.ListedPaths(), "/secret") + assert.Contains(t, c.ListedPaths(), "/nonswagger") + + resp, _ := http.Get(s.URL + "/nonswagger") + assert.Equal(t, first, 1) + assert.Equal(t, resp.StatusCode, http.StatusOK) + + c.Unregister("/nonswagger") + assert.NotContains(t, c.ListedPaths(), "/nonswagger") + + resp, _ = http.Get(s.URL + "/nonswagger") + assert.Equal(t, first, 1) + assert.Equal(t, resp.StatusCode, http.StatusNotFound) + + c.HandleFunc("/nonswagger", func(http.ResponseWriter, *http.Request) { + second = second + 1 + }) + assert.Contains(t, c.ListedPaths(), "/nonswagger") + resp, _ = http.Get(s.URL + "/nonswagger") + assert.Equal(t, first, 1) + assert.Equal(t, second, 1) + assert.Equal(t, resp.StatusCode, http.StatusOK) +}