From 8d3467626ee26cad48ad84f2181552dce7afccb6 Mon Sep 17 00:00:00 2001 From: David Calavera Date: Fri, 8 Apr 2016 16:22:39 -0700 Subject: [PATCH] Move middleware to interfaces. This makes separating middlewares from the core api easier. As an example, the authorization middleware is moved to it's own package. Initialize all static middlewares when the server is created, reducing allocations every time a route is wrapper with the middlewares. Signed-off-by: David Calavera --- api/server/middleware.go | 20 +-------- api/server/middleware/authorization.go | 50 --------------------- api/server/middleware/cors.go | 44 ++++++++++--------- api/server/middleware/debug.go | 2 +- api/server/middleware/middleware.go | 14 ++++-- api/server/middleware/user_agent.go | 54 +++++++++++++---------- api/server/middleware/version.go | 60 ++++++++++++++++---------- api/server/middleware/version_test.go | 8 ++-- api/server/server.go | 23 ++++++---- api/server/server_test.go | 5 +++ docker/daemon.go | 33 ++++++++++++-- pkg/authorization/middleware.go | 60 ++++++++++++++++++++++++++ 12 files changed, 218 insertions(+), 155 deletions(-) delete mode 100644 api/server/middleware/authorization.go create mode 100644 pkg/authorization/middleware.go diff --git a/api/server/middleware.go b/api/server/middleware.go index 31d18ab42c..108e3c077c 100644 --- a/api/server/middleware.go +++ b/api/server/middleware.go @@ -2,10 +2,8 @@ package server import ( "github.com/Sirupsen/logrus" - "github.com/docker/docker/api" "github.com/docker/docker/api/server/httputils" "github.com/docker/docker/api/server/middleware" - "github.com/docker/docker/pkg/authorization" ) // handleWithGlobalMiddlwares wraps the handler function for a request with @@ -14,27 +12,13 @@ import ( func (s *Server) handleWithGlobalMiddlewares(handler httputils.APIFunc) httputils.APIFunc { next := handler - handleVersion := middleware.NewVersionMiddleware(s.cfg.Version, api.DefaultVersion, api.MinVersion) - next = handleVersion(next) - - if s.cfg.EnableCors { - handleCORS := middleware.NewCORSMiddleware(s.cfg.CorsHeaders) - next = handleCORS(next) + for _, m := range s.middlewares { + next = m.WrapHandler(next) } - handleUserAgent := middleware.NewUserAgentMiddleware(s.cfg.Version) - next = handleUserAgent(next) - - // Only want this on debug level if s.cfg.Logging && logrus.GetLevel() == logrus.DebugLevel { next = middleware.DebugRequestMiddleware(next) } - if len(s.cfg.AuthorizationPluginNames) > 0 { - s.authZPlugins = authorization.NewPlugins(s.cfg.AuthorizationPluginNames) - handleAuthorization := middleware.NewAuthorizationMiddleware(s.authZPlugins) - next = handleAuthorization(next) - } - return next } diff --git a/api/server/middleware/authorization.go b/api/server/middleware/authorization.go deleted file mode 100644 index 0163d81fb8..0000000000 --- a/api/server/middleware/authorization.go +++ /dev/null @@ -1,50 +0,0 @@ -package middleware - -import ( - "net/http" - - "github.com/Sirupsen/logrus" - "github.com/docker/docker/api/server/httputils" - "github.com/docker/docker/pkg/authorization" - "golang.org/x/net/context" -) - -// NewAuthorizationMiddleware creates a new Authorization middleware. -func NewAuthorizationMiddleware(plugins []authorization.Plugin) Middleware { - return func(handler httputils.APIFunc) httputils.APIFunc { - return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { - - user := "" - userAuthNMethod := "" - - // Default authorization using existing TLS connection credentials - // FIXME: Non trivial authorization mechanisms (such as advanced certificate validations, kerberos support - // and ldap) will be extracted using AuthN feature, which is tracked under: - // https://github.com/docker/docker/pull/20883 - if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { - user = r.TLS.PeerCertificates[0].Subject.CommonName - userAuthNMethod = "TLS" - } - - authCtx := authorization.NewCtx(plugins, user, userAuthNMethod, r.Method, r.RequestURI) - - if err := authCtx.AuthZRequest(w, r); err != nil { - logrus.Errorf("AuthZRequest for %s %s returned error: %s", r.Method, r.RequestURI, err) - return err - } - - rw := authorization.NewResponseModifier(w) - - if err := handler(ctx, rw, r, vars); err != nil { - logrus.Errorf("Handler for %s %s returned error: %s", r.Method, r.RequestURI, err) - return err - } - - if err := authCtx.AuthZResponse(rw, r); err != nil { - logrus.Errorf("AuthZResponse for %s %s returned error: %s", r.Method, r.RequestURI, err) - return err - } - return nil - } - } -} diff --git a/api/server/middleware/cors.go b/api/server/middleware/cors.go index de21897d2c..ea725dbc72 100644 --- a/api/server/middleware/cors.go +++ b/api/server/middleware/cors.go @@ -4,30 +4,34 @@ import ( "net/http" "github.com/Sirupsen/logrus" - "github.com/docker/docker/api/server/httputils" "golang.org/x/net/context" ) -// NewCORSMiddleware creates a new CORS middleware. -func NewCORSMiddleware(defaultHeaders string) Middleware { - return func(handler httputils.APIFunc) httputils.APIFunc { - return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { - // If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*" - // otherwise, all head values will be passed to HTTP handler - corsHeaders := defaultHeaders - if corsHeaders == "" { - corsHeaders = "*" - } +// CORSMiddleware injects CORS headers to each request +// when it's configured. +type CORSMiddleware struct { + defaultHeaders string +} - writeCorsHeaders(w, r, corsHeaders) - return handler(ctx, w, r, vars) +// NewCORSMiddleware creates a new CORSMiddleware with default headers. +func NewCORSMiddleware(d string) CORSMiddleware { + return CORSMiddleware{defaultHeaders: d} +} + +// WrapHandler returns a new handler function wrapping the previous one in the request chain. +func (c CORSMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + // If "api-cors-header" is not given, but "api-enable-cors" is true, we set cors to "*" + // otherwise, all head values will be passed to HTTP handler + corsHeaders := c.defaultHeaders + if corsHeaders == "" { + corsHeaders = "*" } + + logrus.Debugf("CORS header is enabled and set to: %s", corsHeaders) + w.Header().Add("Access-Control-Allow-Origin", corsHeaders) + w.Header().Add("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, X-Registry-Auth") + w.Header().Add("Access-Control-Allow-Methods", "HEAD, GET, POST, DELETE, PUT, OPTIONS") + return handler(ctx, w, r, vars) } } - -func writeCorsHeaders(w http.ResponseWriter, r *http.Request, corsHeaders string) { - logrus.Debugf("CORS header is enabled and set to: %s", corsHeaders) - w.Header().Add("Access-Control-Allow-Origin", corsHeaders) - w.Header().Add("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, X-Registry-Auth") - w.Header().Add("Access-Control-Allow-Methods", "HEAD, GET, POST, DELETE, PUT, OPTIONS") -} diff --git a/api/server/middleware/debug.go b/api/server/middleware/debug.go index be7056f6c6..6af8aa54d1 100644 --- a/api/server/middleware/debug.go +++ b/api/server/middleware/debug.go @@ -13,7 +13,7 @@ import ( ) // DebugRequestMiddleware dumps the request to logger -func DebugRequestMiddleware(handler httputils.APIFunc) httputils.APIFunc { +func DebugRequestMiddleware(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { logrus.Debugf("Calling %s %s", r.Method, r.RequestURI) diff --git a/api/server/middleware/middleware.go b/api/server/middleware/middleware.go index 588331ae7e..dc1f5bfa0d 100644 --- a/api/server/middleware/middleware.go +++ b/api/server/middleware/middleware.go @@ -1,7 +1,13 @@ package middleware -import "github.com/docker/docker/api/server/httputils" +import ( + "net/http" -// Middleware is an adapter to allow the use of ordinary functions as Docker API filters. -// Any function that has the appropriate signature can be registered as a middleware. -type Middleware func(handler httputils.APIFunc) httputils.APIFunc + "golang.org/x/net/context" +) + +// Middleware is an interface to allow the use of ordinary functions as Docker API filters. +// Any struct that has the appropriate signature can be registered as a middleware. +type Middleware interface { + WrapHandler(func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error +} diff --git a/api/server/middleware/user_agent.go b/api/server/middleware/user_agent.go index 188196bf63..7093830fe3 100644 --- a/api/server/middleware/user_agent.go +++ b/api/server/middleware/user_agent.go @@ -10,28 +10,38 @@ import ( "golang.org/x/net/context" ) -// NewUserAgentMiddleware creates a new UserAgent middleware. -func NewUserAgentMiddleware(versionCheck string) Middleware { - serverVersion := version.Version(versionCheck) +// UserAgentMiddleware is a middleware that +// validates the client user-agent. +type UserAgentMiddleware struct { + serverVersion version.Version +} - return func(handler httputils.APIFunc) httputils.APIFunc { - return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { - ctx = context.WithValue(ctx, httputils.UAStringKey, r.Header.Get("User-Agent")) - - if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") { - userAgent := strings.Split(r.Header.Get("User-Agent"), "/") - - // v1.20 onwards includes the GOOS of the client after the version - // such as Docker/1.7.0 (linux) - if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") { - userAgent[1] = strings.Split(userAgent[1], " ")[0] - } - - if len(userAgent) == 2 && !serverVersion.Equal(version.Version(userAgent[1])) { - logrus.Debugf("Client and server don't have the same version (client: %s, server: %s)", userAgent[1], serverVersion) - } - } - return handler(ctx, w, r, vars) - } +// NewUserAgentMiddleware creates a new UserAgentMiddleware +// with the server version. +func NewUserAgentMiddleware(s version.Version) UserAgentMiddleware { + return UserAgentMiddleware{ + serverVersion: s, + } +} + +// WrapHandler returns a new handler function wrapping the previous one in the request chain. +func (u UserAgentMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + ctx = context.WithValue(ctx, httputils.UAStringKey, r.Header.Get("User-Agent")) + + if strings.Contains(r.Header.Get("User-Agent"), "Docker-Client/") { + userAgent := strings.Split(r.Header.Get("User-Agent"), "/") + + // v1.20 onwards includes the GOOS of the client after the version + // such as Docker/1.7.0 (linux) + if len(userAgent) == 2 && strings.Contains(userAgent[1], " ") { + userAgent[1] = strings.Split(userAgent[1], " ")[0] + } + + if len(userAgent) == 2 && !u.serverVersion.Equal(version.Version(userAgent[1])) { + logrus.Debugf("Client and server don't have the same version (client: %s, server: %s)", userAgent[1], u.serverVersion) + } + } + return handler(ctx, w, r, vars) } } diff --git a/api/server/middleware/version.go b/api/server/middleware/version.go index 41d518bcbc..d09c85a0fa 100644 --- a/api/server/middleware/version.go +++ b/api/server/middleware/version.go @@ -5,7 +5,6 @@ import ( "net/http" "runtime" - "github.com/docker/docker/api/server/httputils" "github.com/docker/docker/pkg/version" "golang.org/x/net/context" ) @@ -18,28 +17,43 @@ func (badRequestError) HTTPErrorStatusCode() int { return http.StatusBadRequest } -// NewVersionMiddleware creates a new Version middleware. -func NewVersionMiddleware(versionCheck string, defaultVersion, minVersion version.Version) Middleware { - serverVersion := version.Version(versionCheck) +// VersionMiddleware is a middleware that +// validates the client and server versions. +type VersionMiddleware struct { + serverVersion version.Version + defaultVersion version.Version + minVersion version.Version +} - return func(handler httputils.APIFunc) httputils.APIFunc { - return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { - apiVersion := version.Version(vars["version"]) - if apiVersion == "" { - apiVersion = defaultVersion - } - - if apiVersion.GreaterThan(defaultVersion) { - return badRequestError{fmt.Errorf("client is newer than server (client API version: %s, server API version: %s)", apiVersion, defaultVersion)} - } - if apiVersion.LessThan(minVersion) { - return badRequestError{fmt.Errorf("client version %s is too old. Minimum supported API version is %s, please upgrade your client to a newer version", apiVersion, minVersion)} - } - - header := fmt.Sprintf("Docker/%s (%s)", serverVersion, runtime.GOOS) - w.Header().Set("Server", header) - ctx = context.WithValue(ctx, httputils.APIVersionKey, apiVersion) - return handler(ctx, w, r, vars) - } +// NewVersionMiddleware creates a new VersionMiddleware +// with the default versions. +func NewVersionMiddleware(s, d, m version.Version) VersionMiddleware { + return VersionMiddleware{ + serverVersion: s, + defaultVersion: d, + minVersion: m, } } + +// WrapHandler returns a new handler function wrapping the previous one in the request chain. +func (v VersionMiddleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + apiVersion := version.Version(vars["version"]) + if apiVersion == "" { + apiVersion = v.defaultVersion + } + + if apiVersion.GreaterThan(v.defaultVersion) { + return badRequestError{fmt.Errorf("client is newer than server (client API version: %s, server API version: %s)", apiVersion, v.defaultVersion)} + } + if apiVersion.LessThan(v.minVersion) { + return badRequestError{fmt.Errorf("client version %s is too old. Minimum supported API version is %s, please upgrade your client to a newer version", apiVersion, v.minVersion)} + } + + header := fmt.Sprintf("Docker/%s (%s)", v.serverVersion, runtime.GOOS) + w.Header().Set("Server", header) + ctx = context.WithValue(ctx, "api-version", apiVersion) + return handler(ctx, w, r, vars) + } + +} diff --git a/api/server/middleware/version_test.go b/api/server/middleware/version_test.go index f60a98e518..89217b0387 100644 --- a/api/server/middleware/version_test.go +++ b/api/server/middleware/version_test.go @@ -21,8 +21,8 @@ func TestVersionMiddleware(t *testing.T) { defaultVersion := version.Version("1.10.0") minVersion := version.Version("1.2.0") - m := NewVersionMiddleware(defaultVersion.String(), defaultVersion, minVersion) - h := m(handler) + m := NewVersionMiddleware(defaultVersion, defaultVersion, minVersion) + h := m.WrapHandler(handler) req, _ := http.NewRequest("GET", "/containers/json", nil) resp := httptest.NewRecorder() @@ -42,8 +42,8 @@ func TestVersionMiddlewareWithErrors(t *testing.T) { defaultVersion := version.Version("1.10.0") minVersion := version.Version("1.2.0") - m := NewVersionMiddleware(defaultVersion.String(), defaultVersion, minVersion) - h := m(handler) + m := NewVersionMiddleware(defaultVersion, defaultVersion, minVersion) + h := m.WrapHandler(handler) req, _ := http.NewRequest("GET", "/containers/json", nil) resp := httptest.NewRecorder() diff --git a/api/server/server.go b/api/server/server.go index 1379b7372e..f406aeeba1 100644 --- a/api/server/server.go +++ b/api/server/server.go @@ -8,8 +8,8 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/docker/api/server/httputils" + "github.com/docker/docker/api/server/middleware" "github.com/docker/docker/api/server/router" - "github.com/docker/docker/pkg/authorization" "github.com/gorilla/mux" "golang.org/x/net/context" ) @@ -20,13 +20,12 @@ const versionMatcher = "/v{version:[0-9.]+}" // Config provides the configuration for the API server type Config struct { - Logging bool - EnableCors bool - CorsHeaders string - AuthorizationPluginNames []string - Version string - SocketGroup string - TLSConfig *tls.Config + Logging bool + EnableCors bool + CorsHeaders string + Version string + SocketGroup string + TLSConfig *tls.Config } // Server contains instance details for the server @@ -34,8 +33,8 @@ type Server struct { cfg *Config servers []*HTTPServer routers []router.Router - authZPlugins []authorization.Plugin routerSwapper *routerSwapper + middlewares []middleware.Middleware } // New returns a new instance of the server based on the specified configuration. @@ -46,6 +45,12 @@ func New(cfg *Config) *Server { } } +// UseMiddleware appends a new middleware to the request chain. +// This needs to be called before the API routes are configured. +func (s *Server) UseMiddleware(m middleware.Middleware) { + s.middlewares = append(s.middlewares, m) +} + // Accept sets a listener the server accepts connections into. func (s *Server) Accept(addr string, listeners ...net.Listener) { for _, listener := range listeners { diff --git a/api/server/server_test.go b/api/server/server_test.go index 9216804152..583283f569 100644 --- a/api/server/server_test.go +++ b/api/server/server_test.go @@ -6,7 +6,10 @@ import ( "strings" "testing" + "github.com/docker/docker/api" "github.com/docker/docker/api/server/httputils" + "github.com/docker/docker/api/server/middleware" + "github.com/docker/docker/pkg/version" "golang.org/x/net/context" ) @@ -19,6 +22,8 @@ func TestMiddlewares(t *testing.T) { cfg: cfg, } + srv.UseMiddleware(middleware.NewVersionMiddleware(version.Version("0.1omega2"), api.DefaultVersion, api.MinVersion)) + req, _ := http.NewRequest("GET", "/containers/json", nil) resp := httptest.NewRecorder() ctx := context.Background() diff --git a/docker/daemon.go b/docker/daemon.go index 804775667c..f4f2e78993 100644 --- a/docker/daemon.go +++ b/docker/daemon.go @@ -14,7 +14,9 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/distribution/uuid" + "github.com/docker/docker/api" apiserver "github.com/docker/docker/api/server" + "github.com/docker/docker/api/server/middleware" "github.com/docker/docker/api/server/router" "github.com/docker/docker/api/server/router/build" "github.com/docker/docker/api/server/router/container" @@ -29,12 +31,14 @@ import ( "github.com/docker/docker/dockerversion" "github.com/docker/docker/libcontainerd" "github.com/docker/docker/opts" + "github.com/docker/docker/pkg/authorization" "github.com/docker/docker/pkg/jsonlog" "github.com/docker/docker/pkg/listeners" flag "github.com/docker/docker/pkg/mflag" "github.com/docker/docker/pkg/pidfile" "github.com/docker/docker/pkg/signal" "github.com/docker/docker/pkg/system" + "github.com/docker/docker/pkg/version" "github.com/docker/docker/registry" "github.com/docker/docker/runconfig" "github.com/docker/docker/utils" @@ -208,10 +212,9 @@ func (cli *DaemonCli) CmdDaemon(args ...string) error { } serverConfig := &apiserver.Config{ - AuthorizationPluginNames: cli.Config.AuthorizationPlugins, - Logging: true, - SocketGroup: cli.Config.SocketGroup, - Version: dockerversion.Version, + Logging: true, + SocketGroup: cli.Config.SocketGroup, + Version: dockerversion.Version, } serverConfig = setPlatformServerConfig(serverConfig, cli.Config) @@ -288,6 +291,7 @@ func (cli *DaemonCli) CmdDaemon(args ...string) error { "graphdriver": d.GraphDriverName(), }).Info("Docker daemon") + cli.initMiddlewares(api, serverConfig) initRouter(api, d) reload := func(config *daemon.Config) { @@ -420,3 +424,24 @@ func initRouter(s *apiserver.Server, d *daemon.Daemon) { s.InitRouter(utils.IsDebugEnabled(), routers...) } + +func (cli *DaemonCli) initMiddlewares(s *apiserver.Server, cfg *apiserver.Config) { + v := version.Version(cfg.Version) + + vm := middleware.NewVersionMiddleware(v, api.DefaultVersion, api.MinVersion) + s.UseMiddleware(vm) + + if cfg.EnableCors { + c := middleware.NewCORSMiddleware(cfg.CorsHeaders) + s.UseMiddleware(c) + } + + u := middleware.NewUserAgentMiddleware(v) + s.UseMiddleware(u) + + if len(cli.Config.AuthorizationPlugins) > 0 { + authZPlugins := authorization.NewPlugins(cli.Config.AuthorizationPlugins) + handleAuthorization := authorization.NewMiddleware(authZPlugins) + s.UseMiddleware(handleAuthorization) + } +} diff --git a/pkg/authorization/middleware.go b/pkg/authorization/middleware.go new file mode 100644 index 0000000000..73511a8148 --- /dev/null +++ b/pkg/authorization/middleware.go @@ -0,0 +1,60 @@ +package authorization + +import ( + "net/http" + + "github.com/Sirupsen/logrus" + "golang.org/x/net/context" +) + +// Middleware uses a list of plugins to +// handle authorization in the API requests. +type Middleware struct { + plugins []Plugin +} + +// NewMiddleware creates a new Middleware +// with a slice of plugins. +func NewMiddleware(p []Plugin) Middleware { + return Middleware{ + plugins: p, + } +} + +// WrapHandler returns a new handler function wrapping the previous one in the request chain. +func (m Middleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + + user := "" + userAuthNMethod := "" + + // Default authorization using existing TLS connection credentials + // FIXME: Non trivial authorization mechanisms (such as advanced certificate validations, kerberos support + // and ldap) will be extracted using AuthN feature, which is tracked under: + // https://github.com/docker/docker/pull/20883 + if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { + user = r.TLS.PeerCertificates[0].Subject.CommonName + userAuthNMethod = "TLS" + } + + authCtx := NewCtx(m.plugins, user, userAuthNMethod, r.Method, r.RequestURI) + + if err := authCtx.AuthZRequest(w, r); err != nil { + logrus.Errorf("AuthZRequest for %s %s returned error: %s", r.Method, r.RequestURI, err) + return err + } + + rw := NewResponseModifier(w) + + if err := handler(ctx, rw, r, vars); err != nil { + logrus.Errorf("Handler for %s %s returned error: %s", r.Method, r.RequestURI, err) + return err + } + + if err := authCtx.AuthZResponse(rw, r); err != nil { + logrus.Errorf("AuthZResponse for %s %s returned error: %s", r.Method, r.RequestURI, err) + return err + } + return nil + } +}