From 578d4a19e96d43a518e89de6371cf12dae12285f Mon Sep 17 00:00:00 2001 From: Alejandro Pedraza Date: Wed, 16 Dec 2020 17:46:14 -0500 Subject: [PATCH] Have the tap APIServer refresh its cert automatically (#5388) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Followup to #5282, fixes #5272 in its totality. This follows the same pattern as the injector/sp-validator webhooks, leveraging `FsCredsWatcher` to watch for changes in the cert files. To reuse code from the webhooks, we moved `updateCert()` to `creds_watcher.go`, and `run()` as well (which now is called `ProcessEvents()`). The `TestNewAPIServer` test in `apiserver_test.go` was removed as it really was just testing two things: (1) that `apiServerAuth` doesn't error which is already covered in the following test, and (2) that the golib call `net.Listen("tcp", addr)` doesn't error, which we're not interested in testing here. ## How to test To test that the injector/sp-validator functionality is still correct, you can refer to #5282 The steps below are similar, but focused towards the tap component: ```bash # Create some root cert $ step certificate create linkerd-tap.linkerd.svc ca.crt ca.key --profile root-ca --no-password --insecure # configure tap's caBundle to be that root cert $ cat > linkerd-overrides.yml << EOF tap: externalSecret: true caBundle: | < ca.crt contents> EOF # Install linkerd $ bin/linkerd install --config linkerd-overrides.yml | k apply -f - # Generate an intermediatery cert with short lifespan $ step certificate create linkerd-tap.linkerd.svc ca-int.crt ca-int.key --ca ca.crt --ca-key ca.key --profile intermediate-ca --not-after 4m --no-password --insecure --san linkerd-tap.linkerd.svc # Create the secret using that intermediate cert $ kubectl create secret tls \ linkerd-tap-k8s-tls \ --cert=ca-int.crt \ --key=ca-int.key \ --namespace=linkerd # Rollout the tap pod for it to pick the new secret $ k -n linkerd rollout restart deploy/linkerd-tap # Tap should work $ bin/linkerd tap -n linkerd deploy/linkerd-web req id=0:0 proxy=in src=10.42.0.15:33040 dst=10.42.0.11:9994 tls=true :method=GET :authority=10.42.0.11:9994 :path=/metrics rsp id=0:0 proxy=in src=10.42.0.15:33040 dst=10.42.0.11:9994 tls=true :status=200 latency=1779µs end id=0:0 proxy=in src=10.42.0.15:33040 dst=10.42.0.11:9994 tls=true duration=65µs response-length=1709B # Wait 5 minutes and rollout tap again $ k -n linkerd rollout restart deploy/linkerd-tap # You'll see in the logs that the cert expired: $ k -n linkerd logs -f deploy/linkerd-tap tap 2020/12/15 16:03:41 http: TLS handshake error from 127.0.0.1:45866: remote error: tls: bad certificate 2020/12/15 16:03:41 http: TLS handshake error from 127.0.0.1:45870: remote error: tls: bad certificate # Recreate the secret $ step certificate create linkerd-tap.linkerd.svc ca-int.crt ca-int.key --ca ca.crt --ca-key ca.key --profile intermediate-ca --not-after 4m --no-password --insecure --san linkerd-tap.linkerd.svc $ k -n linkerd delete secret linkerd-tap-k8s-tls $ kubectl create secret tls \ linkerd-tap-k8s-tls \ --cert=ca-int.crt \ --key=ca-int.key \ --namespace=linkerd # Wait a few moments and you'll see the certs got reloaded and tap is working again time="2020-12-15T16:03:42Z" level=info msg="Updated certificate" addr=":8089" component=apiserver ``` --- controller/cmd/tap/main.go | 18 +----- controller/tap/apiserver.go | 100 +++++++++++++++++++++---------- controller/tap/apiserver_test.go | 56 ++--------------- controller/webhook/server.go | 56 ++++------------- pkg/tls/creds_watcher.go | 67 ++++++++++++++++++--- 5 files changed, 146 insertions(+), 151 deletions(-) diff --git a/controller/cmd/tap/main.go b/controller/cmd/tap/main.go index 67d74c467..dae386674 100644 --- a/controller/cmd/tap/main.go +++ b/controller/cmd/tap/main.go @@ -2,7 +2,6 @@ package tap import ( "context" - "crypto/tls" "flag" "os" "os/signal" @@ -12,7 +11,6 @@ import ( "github.com/linkerd/linkerd2/controller/tap" "github.com/linkerd/linkerd2/pkg/admin" "github.com/linkerd/linkerd2/pkg/flags" - pkgK8s "github.com/linkerd/linkerd2/pkg/k8s" "github.com/linkerd/linkerd2/pkg/trace" log "github.com/sirupsen/logrus" ) @@ -28,8 +26,6 @@ func Main(args []string) { kubeConfigPath := cmd.String("kubeconfig", "", "path to kube config") controllerNamespace := cmd.String("controller-namespace", "linkerd", "namespace in which Linkerd is installed") tapPort := cmd.Uint("tap-port", 4190, "proxy tap port to connect to") - tlsCertPath := cmd.String("tls-cert", pkgK8s.MountPathTLSCrtPEM, "path to TLS Cert PEM") - tlsKeyPath := cmd.String("tls-key", pkgK8s.MountPathTLSKeyPEM, "path to TLS Key PEM") disableCommonNames := cmd.Bool("disable-common-names", false, "disable checks for Common Names (for development)") trustDomain := cmd.String("identity-trust-domain", defaultDomain, "configures the name suffix used for identities") @@ -70,24 +66,14 @@ func Main(args []string) { } grpcTapServer := tap.NewGrpcTapServer(*tapPort, *controllerNamespace, *trustDomain, k8sAPI) - // TODO: make this configurable for local development - cert, err := tls.LoadX509KeyPair(*tlsCertPath, *tlsKeyPath) - if err != nil { - log.Fatal(err.Error()) - } - - apiServer, apiLis, err := tap.NewAPIServer(ctx, *apiServerAddr, cert, k8sAPI, grpcTapServer, *disableCommonNames) + apiServer, err := tap.NewAPIServer(ctx, *apiServerAddr, k8sAPI, grpcTapServer, *disableCommonNames) if err != nil { log.Fatal(err.Error()) } k8sAPI.Sync(nil) // blocks until caches are synced - go func() { - log.Infof("starting APIServer on %s", *apiServerAddr) - apiServer.ServeTLS(apiLis, "", "") - }() - + go apiServer.Start(ctx) go admin.StartServer(*metricsAddr) <-stop diff --git a/controller/tap/apiserver.go b/controller/tap/apiserver.go index 0b4daf5f5..d9c329cca 100644 --- a/controller/tap/apiserver.go +++ b/controller/tap/apiserver.go @@ -8,19 +8,27 @@ import ( "fmt" "net" "net/http" + "sync/atomic" "github.com/julienschmidt/httprouter" "github.com/linkerd/linkerd2/controller/gen/controller/tap" "github.com/linkerd/linkerd2/controller/k8s" k8sutils "github.com/linkerd/linkerd2/pkg/k8s" + pkgk8s "github.com/linkerd/linkerd2/pkg/k8s" "github.com/linkerd/linkerd2/pkg/prometheus" + pkgTls "github.com/linkerd/linkerd2/pkg/tls" + "github.com/prometheus/common/log" "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -type apiServer struct { +// APIServer holds the underlying http server and its config +type APIServer struct { + *http.Server + listener net.Listener router *httprouter.Router allowedNames []string + certValue *atomic.Value log *logrus.Entry } @@ -28,14 +36,23 @@ type apiServer struct { func NewAPIServer( ctx context.Context, addr string, - cert tls.Certificate, k8sAPI *k8s.API, grpcTapServer tap.TapServer, disableCommonNames bool, -) (*http.Server, net.Listener, error) { +) (*APIServer, error) { + updateEvent := make(chan struct{}) + errEvent := make(chan error) + watcher := pkgTls.NewFsCredsWatcher(pkgk8s.MountPathTLSBase, updateEvent, errEvent). + WithFilePaths(pkgk8s.MountPathTLSCrtPEM, pkgk8s.MountPathTLSKeyPEM) + go func() { + if err := watcher.StartWatching(ctx); err != nil { + log.Fatalf("Failed to start creds watcher: %s", err) + } + }() + clientCAPem, allowedNames, usernameHeader, groupHeader, err := apiServerAuth(ctx, k8sAPI) if err != nil { - return nil, nil, err + return nil, err } // for development @@ -48,6 +65,18 @@ func NewAPIServer( "addr": addr, }) + clientCertPool := x509.NewCertPool() + clientCertPool.AppendCertsFromPEM([]byte(clientCAPem)) + + httpServer := &http.Server{ + Addr: addr, + TLSConfig: &tls.Config{ + ClientAuth: tls.VerifyClientCertIfGiven, + ClientCAs: clientCertPool, + }, + } + + var emptyCert atomic.Value h := &handler{ k8sAPI: k8sAPI, usernameHeader: usernameHeader, @@ -56,39 +85,48 @@ func NewAPIServer( log: log, } - router := initRouter(h) - - server := &apiServer{ - router: router, - allowedNames: allowedNames, - log: log, - } - - clientCertPool := x509.NewCertPool() - clientCertPool.AppendCertsFromPEM([]byte(clientCAPem)) - - wrappedServer := prometheus.WithTelemetry(server) - - s := &http.Server{ - Addr: addr, - Handler: wrappedServer, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - ClientAuth: tls.VerifyClientCertIfGiven, - ClientCAs: clientCertPool, - }, - } - lis, err := net.Listen("tcp", addr) if err != nil { - log.Fatalf("net.Listen failed with: %s", err) + return nil, fmt.Errorf("net.Listen failed with: %s", err) } - return s, lis, nil + s := &APIServer{ + Server: httpServer, + listener: lis, + router: initRouter(h), + allowedNames: allowedNames, + certValue: &emptyCert, + log: log, + } + s.Handler = prometheus.WithTelemetry(s) + httpServer.TLSConfig.GetCertificate = s.getCertificate + + if err := watcher.UpdateCert(s.certValue); err != nil { + return nil, fmt.Errorf("Failed to initialized certificate: %s", err) + } + + go watcher.ProcessEvents(log, s.certValue, updateEvent, errEvent) + + return s, nil +} + +// Start starts the https server +func (a *APIServer) Start(ctx context.Context) { + a.log.Infof("starting APIServer on %s", a.Server.Addr) + if err := a.ServeTLS(a.listener, "", ""); err != nil { + if err == http.ErrServerClosed { + return + } + a.log.Fatal(err) + } +} + +func (a *APIServer) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + return a.certValue.Load().(*tls.Certificate), nil } // ServeHTTP handles all routes for the APIServer. -func (a *apiServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (a *APIServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { a.log.Debugf("ServeHTTP(): %+v", req) if err := a.validate(req); err != nil { a.log.Debug(err) @@ -99,7 +137,7 @@ func (a *apiServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } // validate ensures that the request should be honored returning an error otherwise. -func (a *apiServer) validate(req *http.Request) error { +func (a *APIServer) validate(req *http.Request) error { // if `requestheader-allowed-names` was empty, allow any CN if len(a.allowedNames) > 0 { for _, cn := range a.allowedNames { diff --git a/controller/tap/apiserver_test.go b/controller/tap/apiserver_test.go index 52f53dc2e..1a007a696 100644 --- a/controller/tap/apiserver_test.go +++ b/controller/tap/apiserver_test.go @@ -16,54 +16,6 @@ import ( k8sutils "github.com/linkerd/linkerd2/pkg/k8s" ) -func TestNewAPIServer(t *testing.T) { - expectations := []struct { - k8sRes []string - err error - }{ - { - err: fmt.Errorf("failed to load [%s] config: configmaps %q not found", k8sutils.ExtensionAPIServerAuthenticationConfigMapName, k8sutils.ExtensionAPIServerAuthenticationConfigMapName), - }, - { - err: nil, - k8sRes: []string{` -apiVersion: v1 -kind: ConfigMap -metadata: - name: extension-apiserver-authentication - namespace: kube-system -data: - client-ca-file: 'client-ca-file' - requestheader-allowed-names: '["name1", "name2"]' - requestheader-client-ca-file: 'requestheader-client-ca-file' - requestheader-extra-headers-prefix: '["X-Remote-Extra-"]' - requestheader-group-headers: '["X-Remote-Group"]' - requestheader-username-headers: '["X-Remote-User"]' -`, - }, - }, - } - - ctx := context.Background() - for i, exp := range expectations { - exp := exp // pin - - t.Run(fmt.Sprintf("%d returns a configured API Server", i), func(t *testing.T) { - k8sAPI, err := k8s.NewFakeAPI(exp.k8sRes...) - if err != nil { - t.Fatalf("NewFakeAPI returned an error: %s", err) - } - - fakeGrpcServer := newGRPCTapServer(4190, "controller-ns", "cluster.local", k8sAPI) - - _, _, err = NewAPIServer(ctx, "localhost:0", tls.Certificate{}, k8sAPI, fakeGrpcServer, false) - if !reflect.DeepEqual(err, exp.err) { - t.Errorf("NewAPIServer returned unexpected error: %s, expected: %s", err, exp.err) - } - }) - } -} - func TestAPIServerAuth(t *testing.T) { expectations := []struct { k8sRes []string @@ -138,7 +90,7 @@ func TestValidate(t *testing.T) { req := http.Request{TLS: &tls} - server := apiServer{} + server := APIServer{} if err := server.validate(&req); err != nil { t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error()) } @@ -152,7 +104,7 @@ func TestValidate_ClientAllowed(t *testing.T) { req := http.Request{TLS: &tls} - server := apiServer{allowedNames: []string{"name-trusted"}} + server := APIServer{allowedNames: []string{"name-trusted"}} if err := server.validate(&req); err != nil { t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error()) } @@ -166,7 +118,7 @@ func TestValidate_ClientAllowedViaSAN(t *testing.T) { req := http.Request{TLS: &tls} - server := apiServer{allowedNames: []string{"linkerd.io"}} + server := APIServer{allowedNames: []string{"linkerd.io"}} if err := server.validate(&req); err != nil { t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error()) } @@ -180,7 +132,7 @@ func TestValidate_ClientNotAllowed(t *testing.T) { req := http.Request{TLS: &tls} - server := apiServer{allowedNames: []string{"name-trusted"}} + server := APIServer{allowedNames: []string{"name-trusted"}} if err := server.validate(&req); err == nil { t.Fatalf("Expected request to be rejected for %q", cert.Subject.CommonName) } diff --git a/controller/webhook/server.go b/controller/webhook/server.go index b1588b0c1..542ff9f55 100644 --- a/controller/webhook/server.go +++ b/controller/webhook/server.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" "encoding/json" - "fmt" "io/ioutil" "net/http" "sync/atomic" @@ -12,6 +11,7 @@ import ( "github.com/linkerd/linkerd2/controller/k8s" pkgk8s "github.com/linkerd/linkerd2/pkg/k8s" pkgTls "github.com/linkerd/linkerd2/pkg/tls" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" admissionv1beta1 "k8s.io/api/admission/v1beta1" v1 "k8s.io/api/core/v1" @@ -36,7 +36,7 @@ type Server struct { *http.Server api *k8s.API handler Handler - certValue atomic.Value + certValue *atomic.Value recorder record.EventRecorder } @@ -50,7 +50,8 @@ func NewServer( ) (*Server, error) { updateEvent := make(chan struct{}) errEvent := make(chan error) - watcher := pkgTls.NewFsCredsWatcher(certPath, updateEvent, errEvent) + watcher := pkgTls.NewFsCredsWatcher(certPath, updateEvent, errEvent). + WithFilePaths(pkgk8s.MountPathTLSCrtPEM, pkgk8s.MountPathTLSKeyPEM) go func() { if err := watcher.StartWatching(ctx); err != nil { log.Fatalf("Failed to start creds watcher: %s", err) @@ -71,13 +72,16 @@ func NewServer( recorder := eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: component}) s := getConfiguredServer(server, api, handler, recorder) - if err := s.updateCert(); err != nil { + if err := watcher.UpdateCert(s.certValue); err != nil { log.Fatalf("Failed to initialized certificate: %s", err) } - go func() { - s.run(updateEvent, errEvent) - }() + log := logrus.WithFields(logrus.Fields{ + "component": "proxy-injector", + "addr": addr, + }) + + go watcher.ProcessEvents(log, s.certValue, updateEvent, errEvent) return s, nil } @@ -89,32 +93,12 @@ func getConfiguredServer( recorder record.EventRecorder, ) *Server { var emptyCert atomic.Value - s := &Server{httpServer, api, handler, emptyCert, recorder} + s := &Server{httpServer, api, handler, &emptyCert, recorder} s.Handler = http.HandlerFunc(s.serve) httpServer.TLSConfig.GetCertificate = s.getCertificate return s } -func (s *Server) updateCert() error { - creds, err := pkgTls.ReadPEMCreds( - pkgk8s.MountPathTLSKeyPEM, - pkgk8s.MountPathTLSCrtPEM, - ) - if err != nil { - return fmt.Errorf("failed to read cert from disk: %s", err) - } - - certPEM := creds.EncodePEM() - keyPEM := creds.EncodePrivateKeyPEM() - cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) - if err != nil { - return err - } - s.certValue.Store(&cert) - log.Debug("Certificate has been updated") - return nil -} - // Start starts the https server func (s *Server) Start() { log.Infof("listening at %s", s.Server.Addr) @@ -131,22 +115,6 @@ func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error return s.certValue.Load().(*tls.Certificate), nil } -// run reads from the update and error channels and reloads the certs when necessary -func (s *Server) run(updateEvent <-chan struct{}, errEvent <-chan error) { - for { - select { - case <-updateEvent: - if err := s.updateCert(); err != nil { - log.Warnf("Skipping update as cert could not be read from disk: %s", err) - } else { - log.Infof("Updated certificate") - } - case err := <-errEvent: - log.Warnf("Received error from fs watcher: %s", err) - } - } -} - func (s *Server) serve(res http.ResponseWriter, req *http.Request) { var ( data []byte diff --git a/pkg/tls/creds_watcher.go b/pkg/tls/creds_watcher.go index d519f0afd..53bfeef58 100644 --- a/pkg/tls/creds_watcher.go +++ b/pkg/tls/creds_watcher.go @@ -2,9 +2,13 @@ package tls import ( "context" + "crypto/tls" + "fmt" "path/filepath" + "sync/atomic" "github.com/fsnotify/fsnotify" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) @@ -12,14 +16,23 @@ const dataDirectoryLnName = "..data" // FsCredsWatcher is used to monitor tls credentials on the filesystem type FsCredsWatcher struct { - certPath string - EventChan chan<- struct{} - ErrorChan chan<- error + certRootPath string + certFilePath string + keyFilePath string + EventChan chan<- struct{} + ErrorChan chan<- error } // NewFsCredsWatcher constructs a FsCredsWatcher instance -func NewFsCredsWatcher(certPath string, updateEvent chan<- struct{}, errEvent chan<- error) *FsCredsWatcher { - return &FsCredsWatcher{certPath, updateEvent, errEvent} +func NewFsCredsWatcher(certRootPath string, updateEvent chan<- struct{}, errEvent chan<- error) *FsCredsWatcher { + return &FsCredsWatcher{certRootPath, "", "", updateEvent, errEvent} +} + +// WithFilePaths completes the FsCredsWatcher instance with the cert and key files locations +func (fscw *FsCredsWatcher) WithFilePaths(certFilePath, keyFilePath string) *FsCredsWatcher { + fscw.certFilePath = certFilePath + fscw.keyFilePath = keyFilePath + return fscw } // StartWatching starts watching the filesystem for cert updates @@ -31,7 +44,7 @@ func (fscw *FsCredsWatcher) StartWatching(ctx context.Context) error { defer watcher.Close() // no point of proceeding if we fail to watch this - if err := watcher.Add(fscw.certPath); err != nil { + if err := watcher.Add(fscw.certRootPath); err != nil { return err } @@ -43,12 +56,12 @@ LOOP: // Watching the folder for create events as this indicates // that the secret has been updated. if event.Op&fsnotify.Create == fsnotify.Create && - event.Name == filepath.Join(fscw.certPath, dataDirectoryLnName) { + event.Name == filepath.Join(fscw.certRootPath, dataDirectoryLnName) { fscw.EventChan <- struct{}{} } case err := <-watcher.Errors: fscw.ErrorChan <- err - log.Warnf("Error while watching %s: %s", fscw.certPath, err) + log.Warnf("Error while watching %s: %s", fscw.certRootPath, err) break LOOP case <-ctx.Done(): if err := ctx.Err(); err != nil { @@ -60,3 +73,41 @@ LOOP: return nil } + +// UpdateCert reads the cert and key files and stores the key pair in certVal +func (fscw *FsCredsWatcher) UpdateCert(certVal *atomic.Value) error { + creds, err := ReadPEMCreds(fscw.keyFilePath, fscw.certFilePath) + if err != nil { + return fmt.Errorf("failed to read cert from disk: %s", err) + } + + certPEM := creds.EncodePEM() + keyPEM := creds.EncodePrivateKeyPEM() + cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + return err + } + certVal.Store(&cert) + return nil +} + +// ProcessEvents reads from the update and error channels and reloads the certs when necessary +func (fscw *FsCredsWatcher) ProcessEvents( + log *logrus.Entry, + certVal *atomic.Value, + updateEvent <-chan struct{}, + errEvent <-chan error, +) { + for { + select { + case <-updateEvent: + if err := fscw.UpdateCert(certVal); err != nil { + log.Warnf("Skipping update as cert could not be read from disk: %s", err) + } else { + log.Infof("Updated certificate") + } + case err := <-errEvent: + log.Warnf("Received error from fs watcher: %s", err) + } + } +}