mirror of https://github.com/linkerd/linkerd2.git
Have the tap APIServer refresh its cert automatically (#5388)
Followup to #5282, fixes #5272 in its totality. This follows the same pattern as the injector/sp-validator webhooks, leveraging `FsCredsWatcher` to watch for changes in the cert files. To reuse code from the webhooks, we moved `updateCert()` to `creds_watcher.go`, and `run()` as well (which now is called `ProcessEvents()`). The `TestNewAPIServer` test in `apiserver_test.go` was removed as it really was just testing two things: (1) that `apiServerAuth` doesn't error which is already covered in the following test, and (2) that the golib call `net.Listen("tcp", addr)` doesn't error, which we're not interested in testing here. ## How to test To test that the injector/sp-validator functionality is still correct, you can refer to #5282 The steps below are similar, but focused towards the tap component: ```bash # Create some root cert $ step certificate create linkerd-tap.linkerd.svc ca.crt ca.key --profile root-ca --no-password --insecure # configure tap's caBundle to be that root cert $ cat > linkerd-overrides.yml << EOF tap: externalSecret: true caBundle: | < ca.crt contents> EOF # Install linkerd $ bin/linkerd install --config linkerd-overrides.yml | k apply -f - # Generate an intermediatery cert with short lifespan $ step certificate create linkerd-tap.linkerd.svc ca-int.crt ca-int.key --ca ca.crt --ca-key ca.key --profile intermediate-ca --not-after 4m --no-password --insecure --san linkerd-tap.linkerd.svc # Create the secret using that intermediate cert $ kubectl create secret tls \ linkerd-tap-k8s-tls \ --cert=ca-int.crt \ --key=ca-int.key \ --namespace=linkerd # Rollout the tap pod for it to pick the new secret $ k -n linkerd rollout restart deploy/linkerd-tap # Tap should work $ bin/linkerd tap -n linkerd deploy/linkerd-web req id=0:0 proxy=in src=10.42.0.15:33040 dst=10.42.0.11:9994 tls=true :method=GET :authority=10.42.0.11:9994 :path=/metrics rsp id=0:0 proxy=in src=10.42.0.15:33040 dst=10.42.0.11:9994 tls=true :status=200 latency=1779µs end id=0:0 proxy=in src=10.42.0.15:33040 dst=10.42.0.11:9994 tls=true duration=65µs response-length=1709B # Wait 5 minutes and rollout tap again $ k -n linkerd rollout restart deploy/linkerd-tap # You'll see in the logs that the cert expired: $ k -n linkerd logs -f deploy/linkerd-tap tap 2020/12/15 16:03:41 http: TLS handshake error from 127.0.0.1:45866: remote error: tls: bad certificate 2020/12/15 16:03:41 http: TLS handshake error from 127.0.0.1:45870: remote error: tls: bad certificate # Recreate the secret $ step certificate create linkerd-tap.linkerd.svc ca-int.crt ca-int.key --ca ca.crt --ca-key ca.key --profile intermediate-ca --not-after 4m --no-password --insecure --san linkerd-tap.linkerd.svc $ k -n linkerd delete secret linkerd-tap-k8s-tls $ kubectl create secret tls \ linkerd-tap-k8s-tls \ --cert=ca-int.crt \ --key=ca-int.key \ --namespace=linkerd # Wait a few moments and you'll see the certs got reloaded and tap is working again time="2020-12-15T16:03:42Z" level=info msg="Updated certificate" addr=":8089" component=apiserver ```
This commit is contained in:
parent
589f36c4c2
commit
578d4a19e9
|
@ -2,7 +2,6 @@ package tap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"flag"
|
"flag"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
@ -12,7 +11,6 @@ import (
|
||||||
"github.com/linkerd/linkerd2/controller/tap"
|
"github.com/linkerd/linkerd2/controller/tap"
|
||||||
"github.com/linkerd/linkerd2/pkg/admin"
|
"github.com/linkerd/linkerd2/pkg/admin"
|
||||||
"github.com/linkerd/linkerd2/pkg/flags"
|
"github.com/linkerd/linkerd2/pkg/flags"
|
||||||
pkgK8s "github.com/linkerd/linkerd2/pkg/k8s"
|
|
||||||
"github.com/linkerd/linkerd2/pkg/trace"
|
"github.com/linkerd/linkerd2/pkg/trace"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
@ -28,8 +26,6 @@ func Main(args []string) {
|
||||||
kubeConfigPath := cmd.String("kubeconfig", "", "path to kube config")
|
kubeConfigPath := cmd.String("kubeconfig", "", "path to kube config")
|
||||||
controllerNamespace := cmd.String("controller-namespace", "linkerd", "namespace in which Linkerd is installed")
|
controllerNamespace := cmd.String("controller-namespace", "linkerd", "namespace in which Linkerd is installed")
|
||||||
tapPort := cmd.Uint("tap-port", 4190, "proxy tap port to connect to")
|
tapPort := cmd.Uint("tap-port", 4190, "proxy tap port to connect to")
|
||||||
tlsCertPath := cmd.String("tls-cert", pkgK8s.MountPathTLSCrtPEM, "path to TLS Cert PEM")
|
|
||||||
tlsKeyPath := cmd.String("tls-key", pkgK8s.MountPathTLSKeyPEM, "path to TLS Key PEM")
|
|
||||||
disableCommonNames := cmd.Bool("disable-common-names", false, "disable checks for Common Names (for development)")
|
disableCommonNames := cmd.Bool("disable-common-names", false, "disable checks for Common Names (for development)")
|
||||||
trustDomain := cmd.String("identity-trust-domain", defaultDomain, "configures the name suffix used for identities")
|
trustDomain := cmd.String("identity-trust-domain", defaultDomain, "configures the name suffix used for identities")
|
||||||
|
|
||||||
|
@ -70,24 +66,14 @@ func Main(args []string) {
|
||||||
}
|
}
|
||||||
grpcTapServer := tap.NewGrpcTapServer(*tapPort, *controllerNamespace, *trustDomain, k8sAPI)
|
grpcTapServer := tap.NewGrpcTapServer(*tapPort, *controllerNamespace, *trustDomain, k8sAPI)
|
||||||
|
|
||||||
// TODO: make this configurable for local development
|
apiServer, err := tap.NewAPIServer(ctx, *apiServerAddr, k8sAPI, grpcTapServer, *disableCommonNames)
|
||||||
cert, err := tls.LoadX509KeyPair(*tlsCertPath, *tlsKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
apiServer, apiLis, err := tap.NewAPIServer(ctx, *apiServerAddr, cert, k8sAPI, grpcTapServer, *disableCommonNames)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err.Error())
|
log.Fatal(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
k8sAPI.Sync(nil) // blocks until caches are synced
|
k8sAPI.Sync(nil) // blocks until caches are synced
|
||||||
|
|
||||||
go func() {
|
go apiServer.Start(ctx)
|
||||||
log.Infof("starting APIServer on %s", *apiServerAddr)
|
|
||||||
apiServer.ServeTLS(apiLis, "", "")
|
|
||||||
}()
|
|
||||||
|
|
||||||
go admin.StartServer(*metricsAddr)
|
go admin.StartServer(*metricsAddr)
|
||||||
|
|
||||||
<-stop
|
<-stop
|
||||||
|
|
|
@ -8,19 +8,27 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"github.com/linkerd/linkerd2/controller/gen/controller/tap"
|
"github.com/linkerd/linkerd2/controller/gen/controller/tap"
|
||||||
"github.com/linkerd/linkerd2/controller/k8s"
|
"github.com/linkerd/linkerd2/controller/k8s"
|
||||||
k8sutils "github.com/linkerd/linkerd2/pkg/k8s"
|
k8sutils "github.com/linkerd/linkerd2/pkg/k8s"
|
||||||
|
pkgk8s "github.com/linkerd/linkerd2/pkg/k8s"
|
||||||
"github.com/linkerd/linkerd2/pkg/prometheus"
|
"github.com/linkerd/linkerd2/pkg/prometheus"
|
||||||
|
pkgTls "github.com/linkerd/linkerd2/pkg/tls"
|
||||||
|
"github.com/prometheus/common/log"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
type apiServer struct {
|
// APIServer holds the underlying http server and its config
|
||||||
|
type APIServer struct {
|
||||||
|
*http.Server
|
||||||
|
listener net.Listener
|
||||||
router *httprouter.Router
|
router *httprouter.Router
|
||||||
allowedNames []string
|
allowedNames []string
|
||||||
|
certValue *atomic.Value
|
||||||
log *logrus.Entry
|
log *logrus.Entry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,14 +36,23 @@ type apiServer struct {
|
||||||
func NewAPIServer(
|
func NewAPIServer(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
addr string,
|
addr string,
|
||||||
cert tls.Certificate,
|
|
||||||
k8sAPI *k8s.API,
|
k8sAPI *k8s.API,
|
||||||
grpcTapServer tap.TapServer,
|
grpcTapServer tap.TapServer,
|
||||||
disableCommonNames bool,
|
disableCommonNames bool,
|
||||||
) (*http.Server, net.Listener, error) {
|
) (*APIServer, error) {
|
||||||
|
updateEvent := make(chan struct{})
|
||||||
|
errEvent := make(chan error)
|
||||||
|
watcher := pkgTls.NewFsCredsWatcher(pkgk8s.MountPathTLSBase, updateEvent, errEvent).
|
||||||
|
WithFilePaths(pkgk8s.MountPathTLSCrtPEM, pkgk8s.MountPathTLSKeyPEM)
|
||||||
|
go func() {
|
||||||
|
if err := watcher.StartWatching(ctx); err != nil {
|
||||||
|
log.Fatalf("Failed to start creds watcher: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
clientCAPem, allowedNames, usernameHeader, groupHeader, err := apiServerAuth(ctx, k8sAPI)
|
clientCAPem, allowedNames, usernameHeader, groupHeader, err := apiServerAuth(ctx, k8sAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// for development
|
// for development
|
||||||
|
@ -48,6 +65,18 @@ func NewAPIServer(
|
||||||
"addr": addr,
|
"addr": addr,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
clientCertPool := x509.NewCertPool()
|
||||||
|
clientCertPool.AppendCertsFromPEM([]byte(clientCAPem))
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
ClientAuth: tls.VerifyClientCertIfGiven,
|
||||||
|
ClientCAs: clientCertPool,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var emptyCert atomic.Value
|
||||||
h := &handler{
|
h := &handler{
|
||||||
k8sAPI: k8sAPI,
|
k8sAPI: k8sAPI,
|
||||||
usernameHeader: usernameHeader,
|
usernameHeader: usernameHeader,
|
||||||
|
@ -56,39 +85,48 @@ func NewAPIServer(
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
|
|
||||||
router := initRouter(h)
|
|
||||||
|
|
||||||
server := &apiServer{
|
|
||||||
router: router,
|
|
||||||
allowedNames: allowedNames,
|
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
|
|
||||||
clientCertPool := x509.NewCertPool()
|
|
||||||
clientCertPool.AppendCertsFromPEM([]byte(clientCAPem))
|
|
||||||
|
|
||||||
wrappedServer := prometheus.WithTelemetry(server)
|
|
||||||
|
|
||||||
s := &http.Server{
|
|
||||||
Addr: addr,
|
|
||||||
Handler: wrappedServer,
|
|
||||||
TLSConfig: &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
ClientAuth: tls.VerifyClientCertIfGiven,
|
|
||||||
ClientCAs: clientCertPool,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", addr)
|
lis, err := net.Listen("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("net.Listen failed with: %s", err)
|
return nil, fmt.Errorf("net.Listen failed with: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, lis, nil
|
s := &APIServer{
|
||||||
|
Server: httpServer,
|
||||||
|
listener: lis,
|
||||||
|
router: initRouter(h),
|
||||||
|
allowedNames: allowedNames,
|
||||||
|
certValue: &emptyCert,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
s.Handler = prometheus.WithTelemetry(s)
|
||||||
|
httpServer.TLSConfig.GetCertificate = s.getCertificate
|
||||||
|
|
||||||
|
if err := watcher.UpdateCert(s.certValue); err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to initialized certificate: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go watcher.ProcessEvents(log, s.certValue, updateEvent, errEvent)
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the https server
|
||||||
|
func (a *APIServer) Start(ctx context.Context) {
|
||||||
|
a.log.Infof("starting APIServer on %s", a.Server.Addr)
|
||||||
|
if err := a.ServeTLS(a.listener, "", ""); err != nil {
|
||||||
|
if err == http.ErrServerClosed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a.log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *APIServer) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
return a.certValue.Load().(*tls.Certificate), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP handles all routes for the APIServer.
|
// ServeHTTP handles all routes for the APIServer.
|
||||||
func (a *apiServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (a *APIServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
a.log.Debugf("ServeHTTP(): %+v", req)
|
a.log.Debugf("ServeHTTP(): %+v", req)
|
||||||
if err := a.validate(req); err != nil {
|
if err := a.validate(req); err != nil {
|
||||||
a.log.Debug(err)
|
a.log.Debug(err)
|
||||||
|
@ -99,7 +137,7 @@ func (a *apiServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate ensures that the request should be honored returning an error otherwise.
|
// validate ensures that the request should be honored returning an error otherwise.
|
||||||
func (a *apiServer) validate(req *http.Request) error {
|
func (a *APIServer) validate(req *http.Request) error {
|
||||||
// if `requestheader-allowed-names` was empty, allow any CN
|
// if `requestheader-allowed-names` was empty, allow any CN
|
||||||
if len(a.allowedNames) > 0 {
|
if len(a.allowedNames) > 0 {
|
||||||
for _, cn := range a.allowedNames {
|
for _, cn := range a.allowedNames {
|
||||||
|
|
|
@ -16,54 +16,6 @@ import (
|
||||||
k8sutils "github.com/linkerd/linkerd2/pkg/k8s"
|
k8sutils "github.com/linkerd/linkerd2/pkg/k8s"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewAPIServer(t *testing.T) {
|
|
||||||
expectations := []struct {
|
|
||||||
k8sRes []string
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
err: fmt.Errorf("failed to load [%s] config: configmaps %q not found", k8sutils.ExtensionAPIServerAuthenticationConfigMapName, k8sutils.ExtensionAPIServerAuthenticationConfigMapName),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
err: nil,
|
|
||||||
k8sRes: []string{`
|
|
||||||
apiVersion: v1
|
|
||||||
kind: ConfigMap
|
|
||||||
metadata:
|
|
||||||
name: extension-apiserver-authentication
|
|
||||||
namespace: kube-system
|
|
||||||
data:
|
|
||||||
client-ca-file: 'client-ca-file'
|
|
||||||
requestheader-allowed-names: '["name1", "name2"]'
|
|
||||||
requestheader-client-ca-file: 'requestheader-client-ca-file'
|
|
||||||
requestheader-extra-headers-prefix: '["X-Remote-Extra-"]'
|
|
||||||
requestheader-group-headers: '["X-Remote-Group"]'
|
|
||||||
requestheader-username-headers: '["X-Remote-User"]'
|
|
||||||
`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
for i, exp := range expectations {
|
|
||||||
exp := exp // pin
|
|
||||||
|
|
||||||
t.Run(fmt.Sprintf("%d returns a configured API Server", i), func(t *testing.T) {
|
|
||||||
k8sAPI, err := k8s.NewFakeAPI(exp.k8sRes...)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewFakeAPI returned an error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fakeGrpcServer := newGRPCTapServer(4190, "controller-ns", "cluster.local", k8sAPI)
|
|
||||||
|
|
||||||
_, _, err = NewAPIServer(ctx, "localhost:0", tls.Certificate{}, k8sAPI, fakeGrpcServer, false)
|
|
||||||
if !reflect.DeepEqual(err, exp.err) {
|
|
||||||
t.Errorf("NewAPIServer returned unexpected error: %s, expected: %s", err, exp.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIServerAuth(t *testing.T) {
|
func TestAPIServerAuth(t *testing.T) {
|
||||||
expectations := []struct {
|
expectations := []struct {
|
||||||
k8sRes []string
|
k8sRes []string
|
||||||
|
@ -138,7 +90,7 @@ func TestValidate(t *testing.T) {
|
||||||
|
|
||||||
req := http.Request{TLS: &tls}
|
req := http.Request{TLS: &tls}
|
||||||
|
|
||||||
server := apiServer{}
|
server := APIServer{}
|
||||||
if err := server.validate(&req); err != nil {
|
if err := server.validate(&req); err != nil {
|
||||||
t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error())
|
t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error())
|
||||||
}
|
}
|
||||||
|
@ -152,7 +104,7 @@ func TestValidate_ClientAllowed(t *testing.T) {
|
||||||
|
|
||||||
req := http.Request{TLS: &tls}
|
req := http.Request{TLS: &tls}
|
||||||
|
|
||||||
server := apiServer{allowedNames: []string{"name-trusted"}}
|
server := APIServer{allowedNames: []string{"name-trusted"}}
|
||||||
if err := server.validate(&req); err != nil {
|
if err := server.validate(&req); err != nil {
|
||||||
t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error())
|
t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error())
|
||||||
}
|
}
|
||||||
|
@ -166,7 +118,7 @@ func TestValidate_ClientAllowedViaSAN(t *testing.T) {
|
||||||
|
|
||||||
req := http.Request{TLS: &tls}
|
req := http.Request{TLS: &tls}
|
||||||
|
|
||||||
server := apiServer{allowedNames: []string{"linkerd.io"}}
|
server := APIServer{allowedNames: []string{"linkerd.io"}}
|
||||||
if err := server.validate(&req); err != nil {
|
if err := server.validate(&req); err != nil {
|
||||||
t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error())
|
t.Fatalf("No error expected for %q but encountered %q", cert.Subject.CommonName, err.Error())
|
||||||
}
|
}
|
||||||
|
@ -180,7 +132,7 @@ func TestValidate_ClientNotAllowed(t *testing.T) {
|
||||||
|
|
||||||
req := http.Request{TLS: &tls}
|
req := http.Request{TLS: &tls}
|
||||||
|
|
||||||
server := apiServer{allowedNames: []string{"name-trusted"}}
|
server := APIServer{allowedNames: []string{"name-trusted"}}
|
||||||
if err := server.validate(&req); err == nil {
|
if err := server.validate(&req); err == nil {
|
||||||
t.Fatalf("Expected request to be rejected for %q", cert.Subject.CommonName)
|
t.Fatalf("Expected request to be rejected for %q", cert.Subject.CommonName)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -12,6 +11,7 @@ import (
|
||||||
"github.com/linkerd/linkerd2/controller/k8s"
|
"github.com/linkerd/linkerd2/controller/k8s"
|
||||||
pkgk8s "github.com/linkerd/linkerd2/pkg/k8s"
|
pkgk8s "github.com/linkerd/linkerd2/pkg/k8s"
|
||||||
pkgTls "github.com/linkerd/linkerd2/pkg/tls"
|
pkgTls "github.com/linkerd/linkerd2/pkg/tls"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
admissionv1beta1 "k8s.io/api/admission/v1beta1"
|
admissionv1beta1 "k8s.io/api/admission/v1beta1"
|
||||||
v1 "k8s.io/api/core/v1"
|
v1 "k8s.io/api/core/v1"
|
||||||
|
@ -36,7 +36,7 @@ type Server struct {
|
||||||
*http.Server
|
*http.Server
|
||||||
api *k8s.API
|
api *k8s.API
|
||||||
handler Handler
|
handler Handler
|
||||||
certValue atomic.Value
|
certValue *atomic.Value
|
||||||
recorder record.EventRecorder
|
recorder record.EventRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +50,8 @@ func NewServer(
|
||||||
) (*Server, error) {
|
) (*Server, error) {
|
||||||
updateEvent := make(chan struct{})
|
updateEvent := make(chan struct{})
|
||||||
errEvent := make(chan error)
|
errEvent := make(chan error)
|
||||||
watcher := pkgTls.NewFsCredsWatcher(certPath, updateEvent, errEvent)
|
watcher := pkgTls.NewFsCredsWatcher(certPath, updateEvent, errEvent).
|
||||||
|
WithFilePaths(pkgk8s.MountPathTLSCrtPEM, pkgk8s.MountPathTLSKeyPEM)
|
||||||
go func() {
|
go func() {
|
||||||
if err := watcher.StartWatching(ctx); err != nil {
|
if err := watcher.StartWatching(ctx); err != nil {
|
||||||
log.Fatalf("Failed to start creds watcher: %s", err)
|
log.Fatalf("Failed to start creds watcher: %s", err)
|
||||||
|
@ -71,13 +72,16 @@ func NewServer(
|
||||||
recorder := eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: component})
|
recorder := eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: component})
|
||||||
|
|
||||||
s := getConfiguredServer(server, api, handler, recorder)
|
s := getConfiguredServer(server, api, handler, recorder)
|
||||||
if err := s.updateCert(); err != nil {
|
if err := watcher.UpdateCert(s.certValue); err != nil {
|
||||||
log.Fatalf("Failed to initialized certificate: %s", err)
|
log.Fatalf("Failed to initialized certificate: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
log := logrus.WithFields(logrus.Fields{
|
||||||
s.run(updateEvent, errEvent)
|
"component": "proxy-injector",
|
||||||
}()
|
"addr": addr,
|
||||||
|
})
|
||||||
|
|
||||||
|
go watcher.ProcessEvents(log, s.certValue, updateEvent, errEvent)
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
@ -89,32 +93,12 @@ func getConfiguredServer(
|
||||||
recorder record.EventRecorder,
|
recorder record.EventRecorder,
|
||||||
) *Server {
|
) *Server {
|
||||||
var emptyCert atomic.Value
|
var emptyCert atomic.Value
|
||||||
s := &Server{httpServer, api, handler, emptyCert, recorder}
|
s := &Server{httpServer, api, handler, &emptyCert, recorder}
|
||||||
s.Handler = http.HandlerFunc(s.serve)
|
s.Handler = http.HandlerFunc(s.serve)
|
||||||
httpServer.TLSConfig.GetCertificate = s.getCertificate
|
httpServer.TLSConfig.GetCertificate = s.getCertificate
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) updateCert() error {
|
|
||||||
creds, err := pkgTls.ReadPEMCreds(
|
|
||||||
pkgk8s.MountPathTLSKeyPEM,
|
|
||||||
pkgk8s.MountPathTLSCrtPEM,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read cert from disk: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
certPEM := creds.EncodePEM()
|
|
||||||
keyPEM := creds.EncodePrivateKeyPEM()
|
|
||||||
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.certValue.Store(&cert)
|
|
||||||
log.Debug("Certificate has been updated")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the https server
|
// Start starts the https server
|
||||||
func (s *Server) Start() {
|
func (s *Server) Start() {
|
||||||
log.Infof("listening at %s", s.Server.Addr)
|
log.Infof("listening at %s", s.Server.Addr)
|
||||||
|
@ -131,22 +115,6 @@ func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error
|
||||||
return s.certValue.Load().(*tls.Certificate), nil
|
return s.certValue.Load().(*tls.Certificate), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// run reads from the update and error channels and reloads the certs when necessary
|
|
||||||
func (s *Server) run(updateEvent <-chan struct{}, errEvent <-chan error) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-updateEvent:
|
|
||||||
if err := s.updateCert(); err != nil {
|
|
||||||
log.Warnf("Skipping update as cert could not be read from disk: %s", err)
|
|
||||||
} else {
|
|
||||||
log.Infof("Updated certificate")
|
|
||||||
}
|
|
||||||
case err := <-errEvent:
|
|
||||||
log.Warnf("Received error from fs watcher: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) serve(res http.ResponseWriter, req *http.Request) {
|
func (s *Server) serve(res http.ResponseWriter, req *http.Request) {
|
||||||
var (
|
var (
|
||||||
data []byte
|
data []byte
|
||||||
|
|
|
@ -2,9 +2,13 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,14 +16,23 @@ const dataDirectoryLnName = "..data"
|
||||||
|
|
||||||
// FsCredsWatcher is used to monitor tls credentials on the filesystem
|
// FsCredsWatcher is used to monitor tls credentials on the filesystem
|
||||||
type FsCredsWatcher struct {
|
type FsCredsWatcher struct {
|
||||||
certPath string
|
certRootPath string
|
||||||
EventChan chan<- struct{}
|
certFilePath string
|
||||||
ErrorChan chan<- error
|
keyFilePath string
|
||||||
|
EventChan chan<- struct{}
|
||||||
|
ErrorChan chan<- error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFsCredsWatcher constructs a FsCredsWatcher instance
|
// NewFsCredsWatcher constructs a FsCredsWatcher instance
|
||||||
func NewFsCredsWatcher(certPath string, updateEvent chan<- struct{}, errEvent chan<- error) *FsCredsWatcher {
|
func NewFsCredsWatcher(certRootPath string, updateEvent chan<- struct{}, errEvent chan<- error) *FsCredsWatcher {
|
||||||
return &FsCredsWatcher{certPath, updateEvent, errEvent}
|
return &FsCredsWatcher{certRootPath, "", "", updateEvent, errEvent}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFilePaths completes the FsCredsWatcher instance with the cert and key files locations
|
||||||
|
func (fscw *FsCredsWatcher) WithFilePaths(certFilePath, keyFilePath string) *FsCredsWatcher {
|
||||||
|
fscw.certFilePath = certFilePath
|
||||||
|
fscw.keyFilePath = keyFilePath
|
||||||
|
return fscw
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartWatching starts watching the filesystem for cert updates
|
// StartWatching starts watching the filesystem for cert updates
|
||||||
|
@ -31,7 +44,7 @@ func (fscw *FsCredsWatcher) StartWatching(ctx context.Context) error {
|
||||||
defer watcher.Close()
|
defer watcher.Close()
|
||||||
|
|
||||||
// no point of proceeding if we fail to watch this
|
// no point of proceeding if we fail to watch this
|
||||||
if err := watcher.Add(fscw.certPath); err != nil {
|
if err := watcher.Add(fscw.certRootPath); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,12 +56,12 @@ LOOP:
|
||||||
// Watching the folder for create events as this indicates
|
// Watching the folder for create events as this indicates
|
||||||
// that the secret has been updated.
|
// that the secret has been updated.
|
||||||
if event.Op&fsnotify.Create == fsnotify.Create &&
|
if event.Op&fsnotify.Create == fsnotify.Create &&
|
||||||
event.Name == filepath.Join(fscw.certPath, dataDirectoryLnName) {
|
event.Name == filepath.Join(fscw.certRootPath, dataDirectoryLnName) {
|
||||||
fscw.EventChan <- struct{}{}
|
fscw.EventChan <- struct{}{}
|
||||||
}
|
}
|
||||||
case err := <-watcher.Errors:
|
case err := <-watcher.Errors:
|
||||||
fscw.ErrorChan <- err
|
fscw.ErrorChan <- err
|
||||||
log.Warnf("Error while watching %s: %s", fscw.certPath, err)
|
log.Warnf("Error while watching %s: %s", fscw.certRootPath, err)
|
||||||
break LOOP
|
break LOOP
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
if err := ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
|
@ -60,3 +73,41 @@ LOOP:
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateCert reads the cert and key files and stores the key pair in certVal
|
||||||
|
func (fscw *FsCredsWatcher) UpdateCert(certVal *atomic.Value) error {
|
||||||
|
creds, err := ReadPEMCreds(fscw.keyFilePath, fscw.certFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read cert from disk: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM := creds.EncodePEM()
|
||||||
|
keyPEM := creds.EncodePrivateKeyPEM()
|
||||||
|
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
certVal.Store(&cert)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessEvents reads from the update and error channels and reloads the certs when necessary
|
||||||
|
func (fscw *FsCredsWatcher) ProcessEvents(
|
||||||
|
log *logrus.Entry,
|
||||||
|
certVal *atomic.Value,
|
||||||
|
updateEvent <-chan struct{},
|
||||||
|
errEvent <-chan error,
|
||||||
|
) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-updateEvent:
|
||||||
|
if err := fscw.UpdateCert(certVal); err != nil {
|
||||||
|
log.Warnf("Skipping update as cert could not be read from disk: %s", err)
|
||||||
|
} else {
|
||||||
|
log.Infof("Updated certificate")
|
||||||
|
}
|
||||||
|
case err := <-errEvent:
|
||||||
|
log.Warnf("Received error from fs watcher: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue