diff --git a/controller/cmd/identity/main.go b/controller/cmd/identity/main.go index bcbc1c635..bc5459f38 100644 --- a/controller/cmd/identity/main.go +++ b/controller/cmd/identity/main.go @@ -116,7 +116,7 @@ func Main(args []string) { // // Create and start FS creds watcher // - watcher := idctl.NewFsCredsWatcher(*issuerPath, issuerEvent, issuerError) + watcher := tls.NewFsCredsWatcher(*issuerPath, issuerEvent, issuerError) go func() { if err := watcher.StartWatching(ctx); err != nil { log.Fatalf("Failed to start creds watcher: %s", err) diff --git a/controller/webhook/launcher.go b/controller/webhook/launcher.go index 67a93fcf5..e242aa768 100644 --- a/controller/webhook/launcher.go +++ b/controller/webhook/launcher.go @@ -10,7 +10,6 @@ import ( "github.com/linkerd/linkerd2/controller/k8s" "github.com/linkerd/linkerd2/pkg/admin" pkgk8s "github.com/linkerd/linkerd2/pkg/k8s" - "github.com/linkerd/linkerd2/pkg/tls" log "github.com/sirupsen/logrus" ) @@ -33,12 +32,7 @@ func Launch( log.Fatalf("failed to initialize Kubernetes API: %s", err) } - cred, err := tls.ReadPEMCreds(pkgk8s.MountPathTLSKeyPEM, pkgk8s.MountPathTLSCrtPEM) - if err != nil { - log.Fatalf("failed to read TLS secrets: %s", err) - } - - s, err := NewServer(k8sAPI, addr, cred, handler, component) + s, err := NewServer(ctx, k8sAPI, addr, pkgk8s.MountPathTLSBase, handler, component) if err != nil { log.Fatalf("failed to initialize the webhook server: %s", err) } diff --git a/controller/webhook/server.go b/controller/webhook/server.go index 604075f71..b1588b0c1 100644 --- a/controller/webhook/server.go +++ b/controller/webhook/server.go @@ -4,10 +4,13 @@ import ( "context" "crypto/tls" "encoding/json" + "fmt" "io/ioutil" "net/http" + "sync/atomic" "github.com/linkerd/linkerd2/controller/k8s" + pkgk8s "github.com/linkerd/linkerd2/pkg/k8s" pkgTls "github.com/linkerd/linkerd2/pkg/tls" log "github.com/sirupsen/logrus" admissionv1beta1 "k8s.io/api/admission/v1beta1" @@ -31,28 +34,32 @@ type Handler func( // Server describes the https server implementing the webhook type Server struct { *http.Server - api *k8s.API - handler Handler - recorder record.EventRecorder + api *k8s.API + handler Handler + certValue atomic.Value + recorder record.EventRecorder } // NewServer returns a new instance of Server -func NewServer(api *k8s.API, addr string, cred *pkgTls.Cred, handler Handler, component string) (*Server, error) { - var ( - certPEM = cred.EncodePEM() - keyPEM = cred.EncodePrivateKeyPEM() - ) - - cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) - if err != nil { - return nil, err - } +func NewServer( + ctx context.Context, + api *k8s.API, + addr, certPath string, + handler Handler, + component string, +) (*Server, error) { + updateEvent := make(chan struct{}) + errEvent := make(chan error) + watcher := pkgTls.NewFsCredsWatcher(certPath, updateEvent, errEvent) + go func() { + if err := watcher.StartWatching(ctx); err != nil { + log.Fatalf("Failed to start creds watcher: %s", err) + } + }() server := &http.Server{ - Addr: addr, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - }, + Addr: addr, + TLSConfig: &tls.Config{}, } eventBroadcaster := record.NewBroadcaster() @@ -63,11 +70,51 @@ func NewServer(api *k8s.API, addr string, cred *pkgTls.Cred, handler Handler, co }) recorder := eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: component}) - s := &Server{server, api, handler, recorder} - s.Handler = http.HandlerFunc(s.serve) + s := getConfiguredServer(server, api, handler, recorder) + if err := s.updateCert(); err != nil { + log.Fatalf("Failed to initialized certificate: %s", err) + } + + go func() { + s.run(updateEvent, errEvent) + }() + return s, nil } +func getConfiguredServer( + httpServer *http.Server, + api *k8s.API, + handler Handler, + recorder record.EventRecorder, +) *Server { + var emptyCert atomic.Value + s := &Server{httpServer, api, handler, emptyCert, recorder} + s.Handler = http.HandlerFunc(s.serve) + httpServer.TLSConfig.GetCertificate = s.getCertificate + return s +} + +func (s *Server) updateCert() error { + creds, err := pkgTls.ReadPEMCreds( + pkgk8s.MountPathTLSKeyPEM, + pkgk8s.MountPathTLSCrtPEM, + ) + if err != nil { + return fmt.Errorf("failed to read cert from disk: %s", err) + } + + certPEM := creds.EncodePEM() + keyPEM := creds.EncodePrivateKeyPEM() + cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + return err + } + s.certValue.Store(&cert) + log.Debug("Certificate has been updated") + return nil +} + // Start starts the https server func (s *Server) Start() { log.Infof("listening at %s", s.Server.Addr) @@ -79,6 +126,27 @@ func (s *Server) Start() { } } +// getCertificate provides the TLS server with the current cert +func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + return s.certValue.Load().(*tls.Certificate), nil +} + +// run reads from the update and error channels and reloads the certs when necessary +func (s *Server) run(updateEvent <-chan struct{}, errEvent <-chan error) { + for { + select { + case <-updateEvent: + if err := s.updateCert(); err != nil { + log.Warnf("Skipping update as cert could not be read from disk: %s", err) + } else { + log.Infof("Updated certificate") + } + case err := <-errEvent: + log.Warnf("Received error from fs watcher: %s", err) + } + } +} + func (s *Server) serve(res http.ResponseWriter, req *http.Request) { var ( data []byte diff --git a/controller/webhook/server_test.go b/controller/webhook/server_test.go index 0e93ce619..1e8e08827 100644 --- a/controller/webhook/server_test.go +++ b/controller/webhook/server_test.go @@ -3,6 +3,7 @@ package webhook import ( "bytes" "context" + "crypto/tls" "net/http" "net/http/httptest" "reflect" @@ -12,14 +13,18 @@ import ( "github.com/linkerd/linkerd2/controller/k8s" ) +var mockHTTPServer = &http.Server{ + Addr: ":0", + TLSConfig: &tls.Config{}, +} + func TestServe(t *testing.T) { t.Run("with empty http request body", func(t *testing.T) { k8sAPI, err := k8s.NewFakeAPI() if err != nil { panic(err) } - testServer := &Server{nil, k8sAPI, nil, nil} - + testServer := getConfiguredServer(mockHTTPServer, k8sAPI, nil, nil) in := bytes.NewReader(nil) request := httptest.NewRequest(http.MethodGet, "/", in) @@ -37,8 +42,7 @@ func TestServe(t *testing.T) { } func TestShutdown(t *testing.T) { - server := &http.Server{Addr: ":0"} - testServer := &Server{server, nil, nil, nil} + testServer := getConfiguredServer(mockHTTPServer, nil, nil, nil) go func() { if err := testServer.ListenAndServe(); err != nil { diff --git a/pkg/k8s/labels.go b/pkg/k8s/labels.go index 23435c4c7..1e7493b69 100644 --- a/pkg/k8s/labels.go +++ b/pkg/k8s/labels.go @@ -357,11 +357,14 @@ const ( // store identity credentials. MountPathEndEntity = MountPathBase + "/identity/end-entity" + // MountPathTLSBase is the path at which the TLS cert and key PEM files are mounted + MountPathTLSBase = MountPathBase + "/tls" + // MountPathTLSKeyPEM is the path at which the TLS key PEM file is mounted. - MountPathTLSKeyPEM = MountPathBase + "/tls/tls.key" + MountPathTLSKeyPEM = MountPathTLSBase + "/tls.key" // MountPathTLSCrtPEM is the path at which the TLS cert PEM file is mounted. - MountPathTLSCrtPEM = MountPathBase + "/tls/tls.crt" + MountPathTLSCrtPEM = MountPathTLSBase + "/tls.crt" // MountPathXtablesLock is the path at which the proxy init container mounts xtables // This is necessary for xtables-legacy support diff --git a/controller/identity/creds_watcher.go b/pkg/tls/creds_watcher.go similarity index 70% rename from controller/identity/creds_watcher.go rename to pkg/tls/creds_watcher.go index 9ee8222e4..d519f0afd 100644 --- a/controller/identity/creds_watcher.go +++ b/pkg/tls/creds_watcher.go @@ -1,4 +1,4 @@ -package identity +package tls import ( "context" @@ -12,14 +12,14 @@ const dataDirectoryLnName = "..data" // FsCredsWatcher is used to monitor tls credentials on the filesystem type FsCredsWatcher struct { - issuerPath string - EventChan chan<- struct{} - ErrorChan chan<- error + certPath string + EventChan chan<- struct{} + ErrorChan chan<- error } // NewFsCredsWatcher constructs a FsCredsWatcher instance -func NewFsCredsWatcher(issuerPath string, issuerEvent chan<- struct{}, issuerError chan<- error) *FsCredsWatcher { - return &FsCredsWatcher{issuerPath, issuerEvent, issuerError} +func NewFsCredsWatcher(certPath string, updateEvent chan<- struct{}, errEvent chan<- error) *FsCredsWatcher { + return &FsCredsWatcher{certPath, updateEvent, errEvent} } // StartWatching starts watching the filesystem for cert updates @@ -31,7 +31,7 @@ func (fscw *FsCredsWatcher) StartWatching(ctx context.Context) error { defer watcher.Close() // no point of proceeding if we fail to watch this - if err := watcher.Add(fscw.issuerPath); err != nil { + if err := watcher.Add(fscw.certPath); err != nil { return err } @@ -43,12 +43,12 @@ LOOP: // Watching the folder for create events as this indicates // that the secret has been updated. if event.Op&fsnotify.Create == fsnotify.Create && - event.Name == filepath.Join(fscw.issuerPath, dataDirectoryLnName) { + event.Name == filepath.Join(fscw.certPath, dataDirectoryLnName) { fscw.EventChan <- struct{}{} } case err := <-watcher.Errors: fscw.ErrorChan <- err - log.Warnf("Error while watching %s: %s", fscw.issuerPath, err) + log.Warnf("Error while watching %s: %s", fscw.certPath, err) break LOOP case <-ctx.Done(): if err := ctx.Err(); err != nil {