[5756 Identity Management Overhaul] Injector & daprd (#6761)

* daprd & injector SPIFFE

Signed-off-by: joshvanl <me@joshvanl.dev>

* go mod tidy

Signed-off-by: joshvanl <me@joshvanl.dev>

* Revert placement chart changes

Signed-off-by: joshvanl <me@joshvanl.dev>

* Use correct placement ID in placement gRPC client

Signed-off-by: joshvanl <me@joshvanl.dev>

* Fix actors/placement test deadlock

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Delete unused `pkg/credentials` package

Signed-off-by: joshvanl <me@joshvanl.dev>

* Fix gRPC API server serving TLS

Signed-off-by: joshvanl <me@joshvanl.dev>

* Review comments

Signed-off-by: joshvanl <me@joshvanl.dev>

* Updates error string match now it is SPIFFE ID

Signed-off-by: joshvanl <me@joshvanl.dev>

* Returns `GetHTTPClient` e2e test func

Signed-off-by: joshvanl <me@joshvanl.dev>

* Revert injector to set daprd cert env vars by requesting them from
sentry on patch

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds back descriptions to deprecated daprd flags

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Use port 443 rather than 80 by default in sentry injector deployment
chart

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Default fake security values

Signed-off-by: joshvanl <me@joshvanl.dev>

* Remove unused cert CLI flags completely from injector

Signed-off-by: joshvanl <me@joshvanl.dev>

* Adds context to acl SPIFFE debug log

Signed-off-by: joshvanl <me@joshvanl.dev>

* Use correct debug log func in acl

Signed-off-by: joshvanl <me@joshvanl.dev>

* Name injector service func sigatures

Signed-off-by: joshvanl <me@joshvanl.dev>

* Linting

Signed-off-by: joshvanl <me@joshvanl.dev>

* Use legacy client for dialing sentry to allow backwards control plane
compat

Signed-off-by: joshvanl <me@joshvanl.dev>

* Returns `WithBlock` to placement client by default

Signed-off-by: joshvanl <me@joshvanl.dev>

* Revert placement client WithBlock

Signed-off-by: joshvanl <me@joshvanl.dev>

* Default to empty trust domain in legacy tls config if SVID not present
in peer certificate

Signed-off-by: joshvanl <me@joshvanl.dev>

* Use a non client auth client for sentry certificate request connection

Signed-off-by: joshvanl <me@joshvanl.dev>

* Use sentry client auth with certificate and private if available in
environment variable to satisfy v1.11 sentry server

Signed-off-by: joshvanl <me@joshvanl.dev>

* Add `public` Trust Domain in sentry request for daprd as is required for
v1.11 sentry.

Signed-off-by: joshvanl <me@joshvanl.dev>

* linting

Signed-off-by: joshvanl <me@joshvanl.dev>

---------

Signed-off-by: joshvanl <me@joshvanl.dev>
Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
Co-authored-by: Yaron Schneider <schneider.yaron@live.com>
This commit is contained in:
Josh van Leeuwen 2023-09-11 19:53:33 -04:00 committed by GitHub
parent 94b652fb8a
commit c338385837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
82 changed files with 1643 additions and 1915 deletions

View File

@ -329,7 +329,6 @@ TEST_WITH_RACE=./pkg/acl/... \
./pkg/concurrency/... \
./pkg/config/... \
./pkg/cors/... \
./pkg/credentials/... \
./pkg/diagnostics/... \
./pkg/encryption/... \
./pkg/expr/... \

View File

@ -20,6 +20,10 @@ rules:
- apiGroups: [""]
resources: ["serviceaccounts"]
verbs: ["get", "list"]
- apiGroups: ["admissionregistration.k8s.io"]
resources: ["mutatingwebhookconfigurations"]
verbs: ["patch"]
resourceNames: ["dapr-sidecar-injector"]
{{- if not .Values.global.rbac.namespaced }}
- apiGroups: ["dapr.io"]
resources: ["configurations", "components"]

View File

@ -29,6 +29,7 @@ spec:
{{- toYaml . | nindent 8 }}
{{- end }}
annotations:
dapr.io/control-plane: injector
{{- if eq .Values.global.prometheus.enabled true }}
prometheus.io/scrape: "{{ .Values.global.prometheus.enabled }}"
prometheus.io/port: "{{ .Values.global.prometheus.port }}"
@ -122,19 +123,13 @@ spec:
{{- end }}
- "--healthz-port"
- "{{ .Values.healthzPort }}"
{{- with .Values.global.issuerFilenames }}
- "--issuer-ca-secret-key"
- "{{ .ca }}"
- "--issuer-certificate-secret-key"
- "{{ .cert }}"
- "--issuer-key-secret-key"
- "{{ .key }}"
{{- end }}
env:
- name: TLS_CERT_FILE
value: /dapr/cert/tls.crt
- name: TLS_KEY_FILE
value: /dapr/cert/tls.key
- name: DAPR_TRUST_ANCHORS_FILE
value: /var/run/secrets/dapr.io/tls/ca.crt
- name: DAPR_CONTROL_PLANE_TRUST_DOMAIN
value: {{ .Values.global.mtls.controlPlaneTrustDomain }}
- name: DAPR_SENTRY_ADDRESS
value: {{ if .Values.global.mtls.sentryAddress }}{{ .Values.global.mtls.sentryAddress }}{{ else }}dapr-sentry.{{ .Release.Namespace }}.svc.cluster.local:443{{ end }}
{{- if .Values.kubeClusterDomain }}
- name: KUBE_CLUSTER_DOMAIN
value: "{{ .Values.kubeClusterDomain }}"
@ -194,13 +189,23 @@ spec:
resources:
{{ toYaml .Values.resources | indent 10 }}
volumeMounts:
- name: cert
mountPath: /dapr/cert
- name: dapr-trust-bundle
mountPath: /var/run/secrets/dapr.io/tls
readOnly: true
- name: dapr-identity-token
mountPath: /var/run/secrets/dapr.io/sentrytoken
readOnly: true
volumes:
- name: cert
secret:
secretName: dapr-sidecar-injector-cert
- name: dapr-trust-bundle
configMap:
name: dapr-trust-bundle
- name: dapr-identity-token
projected:
sources:
- serviceAccountToken:
path: token
expirationSeconds: 600
audience: "spiffe://{{ .Values.global.mtls.controlPlaneTrustDomain }}/ns/{{ .Release.Namespace }}/dapr-sentry"
affinity:
nodeAffinity:
requiredDuringSchedulingIgnoredDuringExecution:

View File

@ -1,32 +1,4 @@
{{- if eq .Values.enabled true }}
{{- $existingSecret := lookup "v1" "Secret" .Release.Namespace "dapr-sidecar-injector-cert"}}
{{- $existingWebHookConfig := lookup "admissionregistration.k8s.io/v1" "MutatingWebhookConfiguration" .Release.Namespace "dapr-sidecar-injector"}}
{{- $ca := genCA "dapr-sidecar-injector-ca" 3650 }}
{{- $cn := printf "dapr-sidecar-injector" }}
{{- $altName1 := printf "dapr-sidecar-injector.%s" .Release.Namespace }}
{{- $altName2 := printf "dapr-sidecar-injector.%s.svc" .Release.Namespace }}
{{- $altName3 := printf "dapr-sidecar-injector.%s.svc.cluster" .Release.Namespace }}
{{- $altName4 := printf "dapr-sidecar-injector.%s.svc.cluster.local" .Release.Namespace }}
{{- $cert := genSignedCert $cn nil (list $altName1 $altName2 $altName3 $altName4) 3650 $ca }}
apiVersion: v1
kind: Secret
metadata:
name: dapr-sidecar-injector-cert
namespace: {{ .Release.Namespace }}
labels:
app: dapr-sidecar-injector
{{- range $key, $value := .Values.global.k8sLabels }}
{{ $key }}: {{ tpl $value $ }}
{{- end }}
data:
{{ if $existingSecret }}tls.crt: {{ index $existingSecret.data "tls.crt" }}
{{ else }}tls.crt: {{ b64enc $cert.Cert }}
{{ end }}
{{ if $existingSecret }}tls.key: {{ index $existingSecret.data "tls.key" }}
{{ else }}tls.key: {{ b64enc $cert.Key }}
{{ end }}
---
apiVersion: admissionregistration.k8s.io/v1
kind: MutatingWebhookConfiguration
metadata:
@ -44,7 +16,6 @@ webhooks:
namespace: {{ .Release.Namespace }}
name: dapr-sidecar-injector
path: "/mutate"
caBundle: {{ if $existingWebHookConfig }}{{ (index $existingWebHookConfig.webhooks 0).clientConfig.caBundle }}{{ else }}{{ b64enc $ca.Cert }}{{ end }}
rules:
- apiGroups:
- ""

View File

@ -14,6 +14,7 @@ limitations under the License.
package main
import (
"context"
"fmt"
"os"
@ -34,7 +35,9 @@ import (
secretstoresLoader "github.com/dapr/dapr/pkg/components/secretstores"
stateLoader "github.com/dapr/dapr/pkg/components/state"
workflowsLoader "github.com/dapr/dapr/pkg/components/workflows"
"github.com/dapr/dapr/pkg/concurrency"
"github.com/dapr/dapr/pkg/runtime/registry"
"github.com/dapr/dapr/pkg/security"
"github.com/dapr/dapr/pkg/signals"
"github.com/dapr/dapr/pkg/runtime"
@ -103,49 +106,72 @@ func main() {
WithWorkflows(workflowsLoader.DefaultRegistry)
ctx := signals.Context()
rt, err := runtime.FromConfig(ctx, &runtime.Config{
AppID: opts.AppID,
PlacementServiceHostAddr: opts.PlacementServiceHostAddr,
AllowedOrigins: opts.AllowedOrigins,
ResourcesPath: opts.ResourcesPath,
ControlPlaneAddress: opts.ControlPlaneAddress,
AppProtocol: opts.AppProtocol,
Mode: opts.Mode,
DaprHTTPPort: opts.DaprHTTPPort,
DaprInternalGRPCPort: opts.DaprInternalGRPCPort,
DaprAPIGRPCPort: opts.DaprAPIGRPCPort,
DaprAPIListenAddresses: opts.DaprAPIListenAddresses,
DaprPublicPort: opts.DaprPublicPort,
ApplicationPort: opts.AppPort,
ProfilePort: opts.ProfilePort,
EnableProfiling: opts.EnableProfiling,
AppMaxConcurrency: opts.AppMaxConcurrency,
EnableMTLS: opts.EnableMTLS,
SentryAddress: opts.SentryAddress,
DaprHTTPMaxRequestSize: opts.DaprHTTPMaxRequestSize,
UnixDomainSocket: opts.UnixDomainSocket,
DaprHTTPReadBufferSize: opts.DaprHTTPReadBufferSize,
DaprGracefulShutdownSeconds: opts.DaprGracefulShutdownSeconds,
DisableBuiltinK8sSecretStore: opts.DisableBuiltinK8sSecretStore,
EnableAppHealthCheck: opts.EnableAppHealthCheck,
AppHealthCheckPath: opts.AppHealthCheckPath,
AppHealthProbeInterval: opts.AppHealthProbeInterval,
AppHealthProbeTimeout: opts.AppHealthProbeTimeout,
AppHealthThreshold: opts.AppHealthThreshold,
AppChannelAddress: opts.AppChannelAddress,
EnableAPILogging: opts.EnableAPILogging,
Config: opts.Config,
Metrics: opts.Metrics,
AppSSL: opts.AppSSL,
ComponentsPath: opts.ComponentsPath,
Registry: reg,
secProvider, err := security.New(ctx, security.Options{
SentryAddress: opts.SentryAddress,
ControlPlaneTrustDomain: opts.ControlPlaneTrustDomain,
ControlPlaneNamespace: opts.ControlPlaneNamespace,
TrustAnchors: opts.TrustAnchors,
AppID: opts.AppID,
MTLSEnabled: opts.EnableMTLS,
})
if err != nil {
log.Fatal(err)
}
if err := rt.Run(ctx); err != nil {
log.Fatalf("fatal error from runtime: %s", err)
err = concurrency.NewRunnerManager(
secProvider.Run,
func(ctx context.Context) error {
sec, serr := secProvider.Handler(ctx)
if serr != nil {
return serr
}
rt, rerr := runtime.FromConfig(ctx, &runtime.Config{
AppID: opts.AppID,
PlacementServiceHostAddr: opts.PlacementServiceHostAddr,
AllowedOrigins: opts.AllowedOrigins,
ResourcesPath: opts.ResourcesPath,
ControlPlaneAddress: opts.ControlPlaneAddress,
AppProtocol: opts.AppProtocol,
Mode: opts.Mode,
DaprHTTPPort: opts.DaprHTTPPort,
DaprInternalGRPCPort: opts.DaprInternalGRPCPort,
DaprAPIGRPCPort: opts.DaprAPIGRPCPort,
DaprAPIListenAddresses: opts.DaprAPIListenAddresses,
DaprPublicPort: opts.DaprPublicPort,
ApplicationPort: opts.AppPort,
ProfilePort: opts.ProfilePort,
EnableProfiling: opts.EnableProfiling,
AppMaxConcurrency: opts.AppMaxConcurrency,
EnableMTLS: opts.EnableMTLS,
SentryAddress: opts.SentryAddress,
DaprHTTPMaxRequestSize: opts.DaprHTTPMaxRequestSize,
UnixDomainSocket: opts.UnixDomainSocket,
DaprHTTPReadBufferSize: opts.DaprHTTPReadBufferSize,
DaprGracefulShutdownSeconds: opts.DaprGracefulShutdownSeconds,
DisableBuiltinK8sSecretStore: opts.DisableBuiltinK8sSecretStore,
EnableAppHealthCheck: opts.EnableAppHealthCheck,
AppHealthCheckPath: opts.AppHealthCheckPath,
AppHealthProbeInterval: opts.AppHealthProbeInterval,
AppHealthProbeTimeout: opts.AppHealthProbeTimeout,
AppHealthThreshold: opts.AppHealthThreshold,
AppChannelAddress: opts.AppChannelAddress,
EnableAPILogging: opts.EnableAPILogging,
Config: opts.Config,
Metrics: opts.Metrics,
AppSSL: opts.AppSSL,
ComponentsPath: opts.ComponentsPath,
Registry: reg,
Security: sec,
})
if rerr != nil {
return rerr
}
return rt.Run(ctx)
},
).Run(ctx)
if err != nil {
log.Fatalf("Fatal error from runtime: %s", err)
}
}

View File

@ -15,6 +15,7 @@ package options
import (
"flag"
"os"
"strconv"
"strings"
"time"
@ -25,15 +26,18 @@ import (
"github.com/dapr/dapr/pkg/metrics"
"github.com/dapr/dapr/pkg/modes"
"github.com/dapr/dapr/pkg/runtime"
"github.com/dapr/dapr/pkg/security/consts"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
)
type Options struct {
AppID string
ComponentsPath string
ControlPlaneAddress string
ControlPlaneTrustDomain string
ControlPlaneNamespace string
SentryAddress string
TrustAnchors []byte
AllowedOrigins string
EnableProfiling bool
AppMaxConcurrency int
@ -90,6 +94,8 @@ func New(args []string) *Options {
flag.StringVar(&opts.AppID, "app-id", "", "A unique ID for Dapr. Used for Service Discovery and state")
flag.StringVar(&opts.ControlPlaneAddress, "control-plane-address", "", "Address for a Dapr control plane")
flag.StringVar(&opts.SentryAddress, "sentry-address", "", "Address for the Sentry CA service")
flag.StringVar(&opts.ControlPlaneTrustDomain, "control-plane-trust-domain", "localhost", "Trust domain of the Dapr control plane")
flag.StringVar(&opts.ControlPlaneNamespace, "control-plane-namespace", "default", "Namespace of the Dapr control plane")
flag.StringVar(&opts.PlacementServiceHostAddr, "placement-host-address", "", "Addresses for Dapr Actor Placement servers")
flag.StringVar(&opts.AllowedOrigins, "allowed-origins", cors.DefaultAllowedOrigins, "Allowed HTTP origins")
flag.BoolVar(&opts.EnableProfiling, "enable-profiling", false, "Enable profiling")
@ -121,23 +127,42 @@ func New(args []string) *Options {
// Ignore errors; CommandLine is set for ExitOnError.
flag.CommandLine.Parse(args)
opts.TrustAnchors = []byte(os.Getenv(consts.TrustAnchorsEnvVar))
// flag.Parse() will always set a value to "enableAPILogging", and it will be false whether it's explicitly set to false or unset
// For this flag, we need the third state (unset) so we need to do a bit more work here to check if it's unset, then mark "enableAPILogging" as nil
// It's not the prettiest approach, but…
if !*opts.EnableAPILogging {
if !isFlagPassed("enable-api-logging") {
opts.EnableAPILogging = nil
for _, v := range args {
if strings.HasPrefix(v, "--enable-api-logging") || strings.HasPrefix(v, "-enable-api-logging") {
// This means that enable-api-logging was explicitly set to false
opts.EnableAPILogging = ptr.Of(false)
break
}
}
if !isFlagPassed("control-plane-namespace") {
ns, ok := os.LookupEnv(consts.ControlPlaneNamespaceEnvVar)
if ok {
opts.ControlPlaneNamespace = ns
}
}
if !isFlagPassed("control-plane-trust-domain") {
td, ok := os.LookupEnv(consts.ControlPlaneTrustDomainEnvVar)
if ok {
opts.ControlPlaneTrustDomain = td
}
}
return &opts
}
func isFlagPassed(name string) bool {
found := false
flag.Visit(func(f *flag.Flag) {
if f.Name == name {
found = true
}
})
return found
}
// Flag type. Allows passing a flag multiple times to get a slice of strings.
// It implements the flag.Value interface.
type stringSliceFlag []string

View File

@ -42,3 +42,49 @@ func TestStandaloneGlobalConfig(t *testing.T) {
assert.EqualValues(t, string(modes.StandaloneMode), opts.Mode)
assert.Equal(t, []string{"../../../pkg/config/testdata/metric_disabled.yaml"}, []string(opts.Config))
}
func TestControlPlaneEnvVar(t *testing.T) {
t.Run("should default CLI flags if not defined", func(t *testing.T) {
// reset CommandLine to avoid conflicts from other tests
flag.CommandLine = flag.NewFlagSet("runtime-flag-test-cmd", flag.ExitOnError)
opts := New([]string{})
assert.EqualValues(t, "localhost", opts.ControlPlaneTrustDomain)
assert.EqualValues(t, "default", opts.ControlPlaneNamespace)
})
t.Run("should use CLI flags if defined", func(t *testing.T) {
// reset CommandLine to avoid conflicts from other tests
flag.CommandLine = flag.NewFlagSet("runtime-flag-test-cmd", flag.ExitOnError)
opts := New([]string{"--control-plane-namespace", "flag-namespace", "--control-plane-trust-domain", "flag-trust-domain"})
assert.EqualValues(t, "flag-trust-domain", opts.ControlPlaneTrustDomain)
assert.EqualValues(t, "flag-namespace", opts.ControlPlaneNamespace)
})
t.Run("should use env vars if flags were not defined", func(t *testing.T) {
// reset CommandLine to avoid conflicts from other tests
flag.CommandLine = flag.NewFlagSet("runtime-flag-test-cmd", flag.ExitOnError)
t.Setenv("DAPR_CONTROLPLANE_NAMESPACE", "env-namespace")
t.Setenv("DAPR_CONTROLPLANE_TRUST_DOMAIN", "env-trust-domain")
opts := New([]string{})
assert.EqualValues(t, "env-trust-domain", opts.ControlPlaneTrustDomain)
assert.EqualValues(t, "env-namespace", opts.ControlPlaneNamespace)
})
t.Run("should priorities CLI flags if both flags and env vars are defined", func(t *testing.T) {
// reset CommandLine to avoid conflicts from other tests
flag.CommandLine = flag.NewFlagSet("runtime-flag-test-cmd", flag.ExitOnError)
t.Setenv("DAPR_CONTROLPLANE_NAMESPACE", "env-namespace")
t.Setenv("DAPR_CONTROLPLANE_TRUST_DOMAIN", "env-trust-domain")
opts := New([]string{"--control-plane-namespace", "flag-namespace", "--control-plane-trust-domain", "flag-trust-domain"})
assert.EqualValues(t, "flag-trust-domain", opts.ControlPlaneTrustDomain)
assert.EqualValues(t, "flag-namespace", opts.ControlPlaneNamespace)
})
}

View File

@ -15,15 +15,21 @@ package main
import (
"context"
"encoding/base64"
"fmt"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"github.com/dapr/dapr/cmd/injector/options"
"github.com/dapr/dapr/pkg/buildinfo"
scheme "github.com/dapr/dapr/pkg/client/clientset/versioned"
"github.com/dapr/dapr/pkg/concurrency"
"github.com/dapr/dapr/pkg/health"
"github.com/dapr/dapr/pkg/injector/sentry"
"github.com/dapr/dapr/pkg/injector/service"
"github.com/dapr/dapr/pkg/metrics"
"github.com/dapr/dapr/pkg/security"
"github.com/dapr/dapr/pkg/signals"
"github.com/dapr/dapr/utils"
"github.com/dapr/kit/logger"
@ -75,15 +81,55 @@ func main() {
log.Fatalf("Failed to get authentication uids from services accounts: %s", err)
}
inj, err := service.NewInjector(uids, cfg, daprClient, kubeClient)
secProvider, err := security.New(ctx, security.Options{
SentryAddress: cfg.SentryAddress,
ControlPlaneTrustDomain: cfg.ControlPlaneTrustDomain,
ControlPlaneNamespace: security.CurrentNamespace(),
TrustAnchorsFile: cfg.TrustAnchorsFile,
AppID: "dapr-injector",
MTLSEnabled: true,
})
if err != nil {
log.Fatal(err)
}
inj, err := service.NewInjector(service.Options{
AuthUIDs: uids,
Config: cfg,
DaprClient: daprClient,
KubeClient: kubeClient,
ControlPlaneNamespace: security.CurrentNamespace(),
ControlPlaneTrustDomain: cfg.ControlPlaneTrustDomain,
})
if err != nil {
log.Fatalf("Error creating injector: %v", err)
}
healthzServer := health.NewServer(log)
caBundleCh := make(chan []byte)
mngr := concurrency.NewRunnerManager(
inj.Run,
metricsExporter.Run,
secProvider.Run,
func(ctx context.Context) error {
sec, rerr := secProvider.Handler(ctx)
if rerr != nil {
return rerr
}
sentryID, rerr := security.SentryID(sec.ControlPlaneTrustDomain(), security.CurrentNamespace())
if err != nil {
return rerr
}
requester := sentry.New(sentry.Options{
SentryAddress: cfg.SentryAddress,
SentryID: sentryID,
Security: sec,
})
return inj.Run(ctx,
sec.TLSServerConfigNoClientAuth(),
requester.RequestCertificateFromSentry,
sec.CurrentTrustAnchors,
)
},
func(ctx context.Context) error {
readyErr := inj.Ready(ctx)
if readyErr != nil {
@ -100,6 +146,48 @@ func main() {
}
return nil
},
func(ctx context.Context) error {
sec, rErr := secProvider.Handler(ctx)
if rErr != nil {
return rErr
}
sec.WatchTrustAnchors(ctx, caBundleCh)
return nil
},
// Watch for changes to the trust anchors and update the webhook
// configuration on events.
func(ctx context.Context) error {
sec, rerr := secProvider.Handler(ctx)
if rerr != nil {
return rerr
}
caBundle, rErr := sec.CurrentTrustAnchors()
if rErr != nil {
return rErr
}
// Patch the mutating webhook configuration with the current trust
// anchors.
// Re-patch every time the trust anchors change.
for {
_, rErr = kubeClient.AdmissionregistrationV1().MutatingWebhookConfigurations().Patch(ctx,
"dapr-sidecar-injector",
types.JSONPatchType,
[]byte(`[{"op":"replace","path":"/webhooks/0/clientConfig/caBundle","value":"`+base64.StdEncoding.EncodeToString(caBundle)+`"}]`),
metav1.PatchOptions{},
)
if rErr != nil {
return rErr
}
select {
case caBundle = <-caBundleCh:
case <-ctx.Done():
return nil
}
}
},
)
err = mngr.Run(ctx)

View File

@ -19,7 +19,6 @@ import (
"k8s.io/client-go/util/homedir"
"github.com/dapr/dapr/pkg/credentials"
"github.com/dapr/dapr/pkg/metrics"
"github.com/dapr/kit/logger"
)
@ -36,10 +35,6 @@ func New() *Options {
flag.IntVar(&opts.HealthzPort, "healthz-port", 8080, "The port used for health checks")
flag.StringVar(&credentials.RootCertFilename, "issuer-ca-secret-key", credentials.RootCertFilename, "Certificate Authority certificate secret key")
flag.StringVar(&credentials.IssuerCertFilename, "issuer-certificate-secret-key", credentials.IssuerCertFilename, "Issuer certificate secret key")
flag.StringVar(&credentials.IssuerKeyFilename, "issuer-key-secret-key", credentials.IssuerKeyFilename, "Issuer private key secret key")
if home := homedir.HomeDir(); home != "" {
flag.StringVar(&opts.Kubeconfig, "kubeconfig", filepath.Join(home, ".kube", "config"), "(optional) absolute path to the kubeconfig file")
} else {

View File

@ -88,7 +88,7 @@ func main() {
return raftServer.StartRaft(ctx, sec, nil)
},
metricsExporter.Run,
secProvider.Start,
secProvider.Run,
apiServer.MonitorLeadership,
func(ctx context.Context) error {
var metadataOptions []health.RouterOptions

View File

@ -22,7 +22,6 @@ import (
"github.com/dapr/dapr/cmd/sentry/options"
"github.com/dapr/dapr/pkg/buildinfo"
"github.com/dapr/dapr/pkg/concurrency"
"github.com/dapr/dapr/pkg/credentials"
"github.com/dapr/dapr/pkg/health"
"github.com/dapr/dapr/pkg/metrics"
"github.com/dapr/dapr/pkg/sentry"
@ -65,9 +64,9 @@ func main() {
log.Fatal(err)
}
issuerCertPath := filepath.Join(opts.IssuerCredentialsPath, credentials.IssuerCertFilename)
issuerKeyPath := filepath.Join(opts.IssuerCredentialsPath, credentials.IssuerKeyFilename)
rootCertPath := filepath.Join(opts.IssuerCredentialsPath, credentials.RootCertFilename)
issuerCertPath := filepath.Join(opts.IssuerCredentialsPath, opts.IssuerCertFilename)
issuerKeyPath := filepath.Join(opts.IssuerCredentialsPath, opts.IssuerKeyFilename)
rootCertPath := filepath.Join(opts.IssuerCredentialsPath, opts.RootCAFilename)
cfg, err := config.FromConfigName(opts.ConfigName)
if err != nil {

View File

@ -19,7 +19,6 @@ import (
"k8s.io/client-go/util/homedir"
"github.com/dapr/dapr/pkg/credentials"
"github.com/dapr/dapr/pkg/metrics"
"github.com/dapr/dapr/pkg/sentry/config"
"github.com/dapr/kit/logger"
@ -43,6 +42,10 @@ type Options struct {
Kubeconfig string
Logger logger.Options
Metrics *metrics.Options
RootCAFilename string
IssuerCertFilename string
IssuerKeyFilename string
}
func New() *Options {
@ -50,9 +53,9 @@ func New() *Options {
flag.StringVar(&opts.ConfigName, "config", defaultDaprSystemConfigName, "Path to config file, or name of a configuration object")
flag.StringVar(&opts.IssuerCredentialsPath, "issuer-credentials", defaultCredentialsPath, "Path to the credentials directory holding the issuer data")
flag.StringVar(&credentials.RootCertFilename, "issuer-ca-filename", credentials.RootCertFilename, "Certificate Authority certificate filename")
flag.StringVar(&credentials.IssuerCertFilename, "issuer-certificate-filename", credentials.IssuerCertFilename, "Issuer certificate filename")
flag.StringVar(&credentials.IssuerKeyFilename, "issuer-key-filename", credentials.IssuerKeyFilename, "Issuer private key filename")
flag.StringVar(&opts.RootCAFilename, "issuer-ca-filename", config.DefaultRootCertFilename, "Certificate Authority certificate filename")
flag.StringVar(&opts.IssuerCertFilename, "issuer-certificate-filename", config.DefaultIssuerCertFilename, "Issuer certificate filename")
flag.StringVar(&opts.IssuerKeyFilename, "issuer-key-filename", config.DefaultIssuerKeyFilename, "Issuer private key filename")
flag.StringVar(&opts.TrustDomain, "trust-domain", "localhost", "The CA trust domain")
flag.StringVar(&opts.TokenAudience, "token-audience", "", "DEPRECATED, flag has no effect.")
flag.IntVar(&opts.Port, "port", config.DefaultPort, "The port for the sentry server to listen on")

View File

@ -26,4 +26,4 @@ We reviewed parity of state store APIs .
## Consequences
No changes needed to bring the parity among state store APIs. APIs continue to remain same as current 0.10.0 version.
No changes needed to bring the parity among state store APIs. APIs continue to remain same as current 0.10.0 version.

View File

@ -16,24 +16,17 @@ package acl
import (
"context"
"encoding/asn1"
"errors"
"fmt"
"strings"
"github.com/PuerkitoBio/purell"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"github.com/dapr/kit/logger"
"github.com/dapr/dapr/pkg/config"
diag "github.com/dapr/dapr/pkg/diagnostics"
commonv1pb "github.com/dapr/dapr/pkg/proto/common/v1"
)
const (
SpiffeIDPrefix = "spiffe://"
"github.com/dapr/dapr/pkg/security/spiffe"
)
var log = logger.NewLogger("dapr.acl")
@ -148,92 +141,6 @@ func ParseAccessControlSpec(accessControlSpec *config.AccessControlSpec, isHTTP
return &accessControlList, nil
}
// GetAndParseSpiffeID retrieves the SPIFFE Id from the cert and parses it.
func GetAndParseSpiffeID(ctx context.Context) (*SpiffeID, error) {
spiffeID, err := getSpiffeID(ctx)
if err != nil {
return nil, err
}
id, err := parseSpiffeID(spiffeID)
return id, err
}
func parseSpiffeID(spiffeID string) (*SpiffeID, error) {
if spiffeID == "" {
return nil, errors.New("input spiffe id string is empty")
}
if !strings.HasPrefix(spiffeID, SpiffeIDPrefix) {
return nil, fmt.Errorf("input spiffe id: %s is invalid", spiffeID)
}
// The SPIFFE Id will be of the format: spiffe://<trust-domain/ns/<namespace>/<app-id>
parts := strings.Split(spiffeID, "/")
if len(parts) < 6 {
return nil, fmt.Errorf("input spiffe id: %s is invalid", spiffeID)
}
var id SpiffeID
id.TrustDomain = parts[2]
id.Namespace = parts[4]
id.AppID = parts[5]
return &id, nil
}
func getSpiffeID(ctx context.Context) (string, error) {
var spiffeID string
peer, ok := peer.FromContext(ctx)
if ok {
if peer == nil || peer.AuthInfo == nil {
return "", errors.New("unable to retrieve peer auth info")
}
tlsInfo := peer.AuthInfo.(credentials.TLSInfo)
// https://www.ietf.org/rfc/rfc3280.txt
oid := asn1.ObjectIdentifier{2, 5, 29, 17}
for _, crt := range tlsInfo.State.PeerCertificates {
for _, ext := range crt.Extensions {
if ext.Id.Equal(oid) {
var sequence asn1.RawValue
if rest, err := asn1.Unmarshal(ext.Value, &sequence); err != nil {
log.Debug(err)
continue
} else if len(rest) != 0 {
log.Debug("the SAN extension is incorrectly encoded")
continue
}
if !sequence.IsCompound || sequence.Tag != asn1.TagSequence || sequence.Class != asn1.ClassUniversal {
log.Debug("the SAN extension is incorrectly encoded")
continue
}
for bytes := sequence.Bytes; len(bytes) > 0; {
var rawValue asn1.RawValue
var err error
bytes, err = asn1.Unmarshal(bytes, &rawValue)
if err != nil {
return "", err
}
spiffeID = string(rawValue.Bytes)
if strings.HasPrefix(spiffeID, SpiffeIDPrefix) {
return spiffeID, nil
}
}
}
}
}
}
return "", nil
}
func normalizeOperation(operation string) (string, error) {
s, err := purell.NormalizeURLString(operation, purell.FlagsUsuallySafeGreedy|purell.FlagRemoveDuplicateSlashes)
if err != nil {
@ -244,16 +151,15 @@ func normalizeOperation(operation string) (string, error) {
func ApplyAccessControlPolicies(ctx context.Context, operation string, httpVerb commonv1pb.HTTPExtension_Verb, isHTTP bool, acl *config.AccessControlList) (bool, string) {
// Apply access control list filter
spiffeID, err := GetAndParseSpiffeID(ctx)
spiffeID, ok, err := spiffe.FromGRPCContext(ctx)
if err != nil {
// Apply the default action
log.Debugf("Error while reading spiffe id from client cert: %v. applying default global policy action", err)
log.Debugf("failed to get SPIFFE ID from gRPC connection context: %v", err)
return false, err.Error()
}
var appID, trustDomain, namespace string
if spiffeID != nil {
appID = spiffeID.AppID
namespace = spiffeID.Namespace
trustDomain = spiffeID.TrustDomain
if !ok {
// Apply the default action
log.Debugf("Error while reading spiffe id from client cert. applying default global policy action")
}
operation, err = normalizeOperation(operation)
@ -265,37 +171,37 @@ func ApplyAccessControlPolicies(ctx context.Context, operation string, httpVerb
return false, errMessage
}
action, actionPolicy := IsOperationAllowedByAccessControlPolicy(spiffeID, appID, operation, httpVerb, isHTTP, acl)
emitACLMetrics(actionPolicy, appID, trustDomain, namespace, operation, httpVerb.String(), action)
action, actionPolicy := isOperationAllowedByAccessControlPolicy(spiffeID, operation, httpVerb, isHTTP, acl)
emitACLMetrics(spiffeID, actionPolicy, operation, httpVerb.String(), action)
if !action {
errMessage = fmt.Sprintf("access control policy has denied access to appid: %s operation: %s verb: %s", appID, operation, httpVerb)
log.Debugf(errMessage)
errMessage = fmt.Sprintf("access control policy has denied access to id: %s operation: %s verb: %s", spiffeID.URL(), operation, httpVerb)
log.Debug(errMessage)
}
return action, errMessage
}
func emitACLMetrics(actionPolicy, appID, trustDomain, namespace, operation, verb string, action bool) {
func emitACLMetrics(spiffeID *spiffe.Parsed, actionPolicy, operation, verb string, action bool) {
if action {
switch actionPolicy {
case config.ActionPolicyApp:
diag.DefaultMonitoring.RequestAllowedByAppAction(appID, trustDomain, namespace, operation, verb, action)
diag.DefaultMonitoring.RequestAllowedByAppAction(spiffeID, operation, verb, action)
case config.ActionPolicyGlobal:
diag.DefaultMonitoring.RequestAllowedByGlobalAction(appID, trustDomain, namespace, operation, verb, action)
diag.DefaultMonitoring.RequestAllowedByGlobalAction(spiffeID, operation, verb, action)
}
} else {
switch actionPolicy {
case config.ActionPolicyApp:
diag.DefaultMonitoring.RequestBlockedByAppAction(appID, trustDomain, namespace, operation, verb, action)
diag.DefaultMonitoring.RequestBlockedByAppAction(spiffeID, operation, verb, action)
case config.ActionPolicyGlobal:
diag.DefaultMonitoring.RequestBlockedByGlobalAction(appID, trustDomain, namespace, operation, verb, action)
diag.DefaultMonitoring.RequestBlockedByGlobalAction(spiffeID, operation, verb, action)
}
}
}
// IsOperationAllowedByAccessControlPolicy determines if access control policies allow the operation on the target app.
func IsOperationAllowedByAccessControlPolicy(spiffeID *SpiffeID, srcAppID string, inputOperation string, httpVerb commonv1pb.HTTPExtension_Verb, isHTTP bool, accessControlList *config.AccessControlList) (bool, string) {
// isOperationAllowedByAccessControlPolicy determines if access control policies allow the operation on the target app.
func isOperationAllowedByAccessControlPolicy(spiffeID *spiffe.Parsed, inputOperation string, httpVerb commonv1pb.HTTPExtension_Verb, isHTTP bool, accessControlList *config.AccessControlList) (bool, string) {
if accessControlList == nil {
// No access control list is provided. Do nothing
return isActionAllowed(config.AllowAccess), ""
@ -304,18 +210,13 @@ func IsOperationAllowedByAccessControlPolicy(spiffeID *SpiffeID, srcAppID string
action := accessControlList.DefaultAction
actionPolicy := config.ActionPolicyGlobal
if srcAppID == "" {
// Did not receive the src app id correctly
return isActionAllowed(action), actionPolicy
}
if spiffeID == nil {
// Could not retrieve spiffe id or it is invalid. Apply global default action
return isActionAllowed(action), actionPolicy
}
// Look up the src app id in the in-memory table. The key is appID||namespace
key := getKeyForAppID(srcAppID, spiffeID.Namespace)
key := getKeyForAppID(spiffeID.AppID(), spiffeID.Namespace())
appPolicy, found := accessControlList.PolicySpec[key]
if !found {
@ -324,12 +225,12 @@ func IsOperationAllowedByAccessControlPolicy(spiffeID *SpiffeID, srcAppID string
}
// Match trust domain
if appPolicy.TrustDomain != spiffeID.TrustDomain {
if appPolicy.TrustDomain != spiffeID.TrustDomain().String() {
return isActionAllowed(action), actionPolicy
}
// Match namespace
if appPolicy.Namespace != spiffeID.Namespace {
if appPolicy.Namespace != spiffeID.Namespace() {
return isActionAllowed(action), actionPolicy
}
@ -386,13 +287,5 @@ func isActionAllowed(action string) bool {
}
func getKeyForAppID(appID, namespace string) string {
key := appID + "||" + namespace
return key
}
// SpiffeID represents the separated fields in a spiffe id.
type SpiffeID struct {
TrustDomain string
Namespace string
AppID string
return appID + "||" + namespace
}

View File

@ -17,11 +17,13 @@ package acl
import (
"testing"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dapr/dapr/pkg/config"
"github.com/dapr/dapr/pkg/proto/common/v1"
"github.com/dapr/dapr/pkg/security/spiffe"
)
const (
@ -376,44 +378,14 @@ func TestParseAccessControlSpec(t *testing.T) {
})
}
func TestSpiffeID(t *testing.T) {
t.Run("test parse spiffe id", func(t *testing.T) {
spiffeID := "spiffe://mydomain/ns/mynamespace/myappid"
id, err := parseSpiffeID(spiffeID)
assert.Equal(t, "mydomain", id.TrustDomain)
assert.Equal(t, "mynamespace", id.Namespace)
assert.Equal(t, "myappid", id.AppID)
require.NoError(t, err)
})
t.Run("test parse invalid spiffe id", func(t *testing.T) {
spiffeID := "abcd"
_, err := parseSpiffeID(spiffeID)
require.Error(t, err)
})
t.Run("test parse spiffe id with not all fields", func(t *testing.T) {
spiffeID := "spiffe://mydomain/ns/myappid"
_, err := parseSpiffeID(spiffeID)
require.Error(t, err)
})
t.Run("test empty spiffe id", func(t *testing.T) {
spiffeID := ""
_, err := parseSpiffeID(spiffeID)
require.Error(t, err)
})
}
func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
func Test_isOperationAllowedByAccessControlPolicy(t *testing.T) {
td := spiffeid.RequireTrustDomainFromString("public")
privateTD := spiffeid.RequireTrustDomainFromString("private")
t.Run("test when no acl specified", func(t *testing.T) {
srcAppID := app1
spiffeID := SpiffeID{
TrustDomain: "public",
Namespace: "ns1",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op1", common.HTTPExtension_POST, true, nil)
spiffeID, err := spiffe.FromStrings(td, "ns1", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op1", common.HTTPExtension_POST, true, nil)
// Action = Allow the operation since no ACL is defined
assert.True(t, isAllowed)
})
@ -421,12 +393,9 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when no matching app in acl found", func(t *testing.T) {
srcAppID := "appX"
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "public",
Namespace: "ns1",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op1", common.HTTPExtension_POST, true, accessControlList)
spiffeID, err := spiffe.FromStrings(td, "ns1", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op1", common.HTTPExtension_POST, true, accessControlList)
// Action = Default global action
assert.False(t, isAllowed)
})
@ -434,12 +403,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when trust domain does not match", func(t *testing.T) {
srcAppID := app1
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "private",
Namespace: "ns1",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op1", common.HTTPExtension_POST, true, accessControlList)
spiffeID, err := spiffe.FromStrings(privateTD, "ns1", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op1", common.HTTPExtension_POST, true, accessControlList)
// Action = Ignore policy and apply global default action
assert.False(t, isAllowed)
})
@ -447,28 +414,23 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when namespace does not match", func(t *testing.T) {
srcAppID := app1
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "public",
Namespace: "abcd",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op1", common.HTTPExtension_POST, true, accessControlList)
spiffeID, err := spiffe.FromStrings(td, "abcd", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op1", common.HTTPExtension_POST, true, accessControlList)
// Action = Ignore policy and apply global default action
assert.False(t, isAllowed)
})
t.Run("test when spiffe id is nil", func(t *testing.T) {
srcAppID := app1
accessControlList, _ := initializeAccessControlList(true)
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(nil, srcAppID, "op1", common.HTTPExtension_POST, true, accessControlList)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(nil, "op1", common.HTTPExtension_POST, true, accessControlList)
// Action = Default global action
assert.False(t, isAllowed)
})
t.Run("test when src app id is empty", func(t *testing.T) {
srcAppID := ""
accessControlList, _ := initializeAccessControlList(true)
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(nil, srcAppID, "op1", common.HTTPExtension_POST, true, accessControlList)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(nil, "op1", common.HTTPExtension_POST, true, accessControlList)
// Action = Default global action
assert.False(t, isAllowed)
})
@ -476,12 +438,9 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when operation is not found in the policy spec", func(t *testing.T) {
srcAppID := app1
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "public",
Namespace: "ns1",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "opX", common.HTTPExtension_POST, true, accessControlList)
spiffeID, err := spiffe.FromStrings(td, "ns1", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "opX", common.HTTPExtension_POST, true, accessControlList)
// Action = Ignore policy and apply default action for app
assert.True(t, isAllowed)
})
@ -489,12 +448,9 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test http case-sensitivity when matching operation post fix", func(t *testing.T) {
srcAppID := app1
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "public",
Namespace: "ns1",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "Op2", common.HTTPExtension_POST, true, accessControlList)
spiffeID, err := spiffe.FromStrings(td, "ns1", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "Op2", common.HTTPExtension_POST, true, accessControlList)
// Action = Ignore policy and apply default action for app
assert.False(t, isAllowed)
})
@ -502,12 +458,11 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when http verb is not found", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op4", common.HTTPExtension_PUT, true, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op4", common.HTTPExtension_PUT, true, accessControlList)
// Action = Default action for the specific app
assert.False(t, isAllowed)
})
@ -515,12 +470,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when default action for app is not specified and no matching http verb found", func(t *testing.T) {
srcAppID := app3
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns1",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op5", common.HTTPExtension_PUT, true, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns1", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op5", common.HTTPExtension_PUT, true, accessControlList)
// Action = Global Default action
assert.False(t, isAllowed)
})
@ -528,12 +481,9 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when http verb matches *", func(t *testing.T) {
srcAppID := app1
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "public",
Namespace: "ns1",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op2", common.HTTPExtension_PUT, true, accessControlList)
spiffeID, err := spiffe.FromStrings(td, "ns1", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op2", common.HTTPExtension_PUT, true, accessControlList)
// Action = Default action for the specific verb
assert.False(t, isAllowed)
})
@ -541,12 +491,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when http verb matches a specific verb", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op4", common.HTTPExtension_POST, true, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op4", common.HTTPExtension_POST, true, accessControlList)
// Action = Default action for the specific verb
assert.True(t, isAllowed)
})
@ -554,12 +502,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when operation is invoked with /", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "/op4", common.HTTPExtension_POST, true, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "/op4", common.HTTPExtension_POST, true, accessControlList)
// Action = Default action for the specific verb
assert.True(t, isAllowed)
})
@ -567,12 +513,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when http verb is not specified", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op4", common.HTTPExtension_NONE, true, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op4", common.HTTPExtension_NONE, true, accessControlList)
// Action = Default action for the app
assert.False(t, isAllowed)
})
@ -580,12 +524,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when matching operation post fix is specified in policy spec", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "/op3/a", common.HTTPExtension_PUT, true, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "/op3/a", common.HTTPExtension_PUT, true, accessControlList)
// Action = Default action for the specific verb
assert.True(t, isAllowed)
})
@ -593,12 +535,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test grpc case-sensitivity when matching operation post fix", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(false)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "/OP4", common.HTTPExtension_NONE, false, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "/OP4", common.HTTPExtension_NONE, false, accessControlList)
// Action = Default action for the specific verb
assert.False(t, isAllowed)
})
@ -606,12 +546,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when non-matching operation post fix is specified in policy spec", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "/op3/b/b", common.HTTPExtension_PUT, true, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "/op3/b/b", common.HTTPExtension_PUT, true, accessControlList)
// Action = Default action for the app
assert.False(t, isAllowed)
})
@ -619,12 +557,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test when non-matching operation post fix is specified in policy spec", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(true)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "/op3/a/b", common.HTTPExtension_PUT, true, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "/op3/a/b", common.HTTPExtension_PUT, true, accessControlList)
// Action = Default action for the app
assert.True(t, isAllowed)
})
@ -632,12 +568,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("test with grpc invocation", func(t *testing.T) {
srcAppID := app2
accessControlList, _ := initializeAccessControlList(false)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op4", common.HTTPExtension_NONE, false, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op4", common.HTTPExtension_NONE, false, accessControlList)
// Action = Default action for the app
assert.True(t, isAllowed)
})
@ -645,12 +579,10 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("when testing grpc calls, acl is not configured with http verb", func(t *testing.T) {
srcAppID := app4
accessControlList, _ := initializeAccessControlList(false)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op6", common.HTTPExtension_NONE, false, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op6", common.HTTPExtension_NONE, false, accessControlList)
// Action = Default action for the app
assert.True(t, isAllowed)
})
@ -658,42 +590,36 @@ func TestIsOperationAllowedByAccessControlPolicy(t *testing.T) {
t.Run("when testing grpc calls, acl configured with wildcard * for full matching", func(t *testing.T) {
srcAppID := app4
accessControlList, _ := initializeAccessControlList(false)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op7/a/b/c", common.HTTPExtension_NONE, false, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op7/a/b/c", common.HTTPExtension_NONE, false, accessControlList)
assert.True(t, isAllowed)
isAllowed, _ = IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op7/a/b/c/d", common.HTTPExtension_NONE, false, accessControlList)
isAllowed, _ = isOperationAllowedByAccessControlPolicy(spiffeID, "op7/a/b/c/d", common.HTTPExtension_NONE, false, accessControlList)
assert.False(t, isAllowed)
isAllowed, _ = IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op7/a/b/c/f", common.HTTPExtension_NONE, false, accessControlList)
isAllowed, _ = isOperationAllowedByAccessControlPolicy(spiffeID, "op7/a/b/c/f", common.HTTPExtension_NONE, false, accessControlList)
assert.True(t, isAllowed)
})
t.Run("when testing grpc calls, acl is configured with wildcards", func(t *testing.T) {
srcAppID := app4
accessControlList, _ := initializeAccessControlList(false)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op7/a/bc", common.HTTPExtension_NONE, false, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op7/a/bc", common.HTTPExtension_NONE, false, accessControlList)
assert.True(t, isAllowed)
})
t.Run("when testing grpc calls, acl configured with wildcard ** for full matching", func(t *testing.T) {
srcAppID := app4
accessControlList, _ := initializeAccessControlList(false)
spiffeID := SpiffeID{
TrustDomain: "domain1",
Namespace: "ns2",
AppID: srcAppID,
}
isAllowed, _ := IsOperationAllowedByAccessControlPolicy(&spiffeID, srcAppID, "op7/c/d/e", common.HTTPExtension_NONE, false, accessControlList)
domain1TD := spiffeid.RequireTrustDomainFromString("domain1")
spiffeID, err := spiffe.FromStrings(domain1TD, "ns2", srcAppID)
assert.NoError(t, err)
isAllowed, _ := isOperationAllowedByAccessControlPolicy(spiffeID, "op7/c/d/e", common.HTTPExtension_NONE, false, accessControlList)
assert.True(t, isAllowed)
})
}

View File

@ -40,7 +40,6 @@ import (
"github.com/dapr/dapr/pkg/actors/timers"
"github.com/dapr/dapr/pkg/channel"
configuration "github.com/dapr/dapr/pkg/config"
daprCredentials "github.com/dapr/dapr/pkg/credentials"
diag "github.com/dapr/dapr/pkg/diagnostics"
diagUtils "github.com/dapr/dapr/pkg/diagnostics/utils"
"github.com/dapr/dapr/pkg/health"
@ -52,6 +51,7 @@ import (
"github.com/dapr/dapr/pkg/resiliency"
"github.com/dapr/dapr/pkg/retry"
"github.com/dapr/dapr/pkg/runtime/compstore"
"github.com/dapr/dapr/pkg/security"
"github.com/dapr/kit/logger"
)
@ -120,7 +120,6 @@ type actorsRuntime struct {
actorsReminders internal.RemindersProvider
actorsTable *sync.Map
appHealthy *atomic.Bool
certChain *daprCredentials.CertChain
tracingSpec configuration.TracingSpec
resiliency resiliency.Provider
storeName string
@ -128,6 +127,7 @@ type actorsRuntime struct {
clock clock.WithTicker
internalActors map[string]InternalActor
internalActorChannel *internalActorChannel
sec security.Handler
wg sync.WaitGroup
closed atomic.Bool
closeCh chan struct{}
@ -141,11 +141,11 @@ type ActorsOpts struct {
AppChannel channel.AppChannel
GRPCConnectionFn GRPCConnectionFn
Config Config
CertChain *daprCredentials.CertChain
TracingSpec configuration.TracingSpec
Resiliency resiliency.Provider
StateStoreName string
CompStore *compstore.ComponentStore
Security security.Handler
// TODO: @joshvanl Remove in Dapr 1.12 when ActorStateTTL is finalized.
StateTTLEnabled bool
@ -174,7 +174,6 @@ func newActorsWithClock(opts ActorsOpts, clock clock.WithTicker) ActorRuntime {
actorsConfig: opts.Config,
timers: timers.NewTimersProvider(clock),
actorsReminders: remindersProvider,
certChain: opts.CertChain,
tracingSpec: opts.TracingSpec,
resiliency: opts.Resiliency,
storeName: opts.StateStoreName,
@ -185,6 +184,7 @@ func newActorsWithClock(opts ActorsOpts, clock clock.WithTicker) ActorRuntime {
internalActors: map[string]InternalActor{},
internalActorChannel: newInternalActorChannel(),
compStore: opts.CompStore,
sec: opts.Security,
// TODO: @joshvanl Remove in Dapr 1.12 when ActorStateTTL is finalized.
stateTTLEnabled: opts.StateTTLEnabled,
@ -247,7 +247,7 @@ func (a *actorsRuntime) Init(ctx context.Context) error {
if a.placement == nil {
a.placement = placement.NewActorPlacement(placement.ActorPlacementOpts{
ServerAddrs: a.actorsConfig.Config.PlacementAddresses,
CertChain: a.certChain,
Security: a.sec,
AppID: a.actorsConfig.Config.AppID,
RuntimeHostname: hostname,
PodName: a.actorsConfig.Config.PodName,

View File

@ -72,7 +72,7 @@ func (c *internalActorChannel) Contains(actorType string) bool {
}
// GetAppConfig implements channel.AppChannel
func (c *internalActorChannel) GetAppConfig(appID string) (*config.ApplicationConfig, error) {
func (c *internalActorChannel) GetAppConfig(_ context.Context, appID string) (*config.ApplicationConfig, error) {
actorTypes := make([]string, 0, len(c.actors))
for actorType := range c.actors {
actorTypes = append(actorTypes, actorType)

View File

@ -30,6 +30,7 @@ import (
invokev1 "github.com/dapr/dapr/pkg/messaging/v1"
"github.com/dapr/dapr/pkg/resiliency"
"github.com/dapr/dapr/pkg/runtime/compstore"
"github.com/dapr/dapr/pkg/security/fake"
)
type mockInternalActor struct {
@ -117,6 +118,7 @@ func newTestActorsRuntimeWithInternalActors(internalActors map[string]InternalAc
TracingSpec: spec,
Resiliency: resiliency.New(log),
StateStoreName: "actorStore",
Security: fake.New(),
})
for actorType, actor := range internalActors {

View File

@ -17,9 +17,9 @@ import (
"context"
"sync"
v1pb "github.com/dapr/dapr/pkg/proto/placement/v1"
"google.golang.org/grpc"
v1pb "github.com/dapr/dapr/pkg/proto/placement/v1"
)
// placementClient implements the best practices when handling grpc streams

View File

@ -15,12 +15,16 @@ package placement
import (
"context"
"crypto/x509"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"github.com/dapr/dapr/pkg/security"
)
func TestConnectToServer(t *testing.T) {
@ -49,7 +53,7 @@ func TestConnectToServer(t *testing.T) {
conn, _, cleanup := newTestServer() // do not register the placement stream server
defer cleanup()
client := newPlacementClient(getGrpcOptsGetter([]string{conn}, nil))
client := newPlacementClient(getGrpcOptsGetter([]string{conn}, testSecurity(t)))
var ready sync.WaitGroup
ready.Add(1)
@ -95,7 +99,7 @@ func TestDisconnect(t *testing.T) {
conn, _, cleanup := newTestServer() // do not register the placement stream server
defer cleanup()
client := newPlacementClient(getGrpcOptsGetter([]string{conn}, nil))
client := newPlacementClient(getGrpcOptsGetter([]string{conn}, testSecurity(t)))
assert.Nil(t, client.connectToServer(context.Background(), conn))
called := false
@ -118,3 +122,22 @@ func TestDisconnect(t *testing.T) {
assert.True(t, called)
})
}
func testSecurity(t *testing.T) security.Handler {
secP, err := security.New(context.Background(), security.Options{
TrustAnchors: []byte("test"),
AppID: "test",
ControlPlaneTrustDomain: "test.example.com",
ControlPlaneNamespace: "default",
MTLSEnabled: false,
OverrideCertRequestSource: func(context.Context, []byte) ([]*x509.Certificate, error) {
return []*x509.Certificate{nil}, nil
},
})
require.NoError(t, err)
go secP.Run(context.Background())
sec, err := secP.Handler(context.Background())
require.NoError(t, err)
return sec
}

View File

@ -18,19 +18,20 @@ import (
"strings"
"sync"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"google.golang.org/grpc"
daprCredentials "github.com/dapr/dapr/pkg/credentials"
diag "github.com/dapr/dapr/pkg/diagnostics"
"github.com/dapr/dapr/pkg/runtime/security"
"github.com/dapr/dapr/pkg/security"
)
var errEstablishingTLSConn = errors.New("failed to establish TLS credentials for actor placement service")
// getGrpcOptsGetter returns a function that provides the grpc options and once defined, a cached version will be returned.
func getGrpcOptsGetter(servers []string, clientCert *daprCredentials.CertChain) func() ([]grpc.DialOption, error) {
func getGrpcOptsGetter(servers []string, sec security.Handler) func() ([]grpc.DialOption, error) {
mu := sync.RWMutex{}
var cached []grpc.DialOption
return func() ([]grpc.DialOption, error) {
mu.RLock()
if cached != nil {
@ -45,17 +46,19 @@ func getGrpcOptsGetter(servers []string, clientCert *daprCredentials.CertChain)
return cached, nil
}
opts, err := daprCredentials.GetClientOptions(clientCert, security.TLSServerName)
var opts []grpc.DialOption
placementID, err := spiffeid.FromSegments(sec.ControlPlaneTrustDomain(), "ns", sec.ControlPlaneNamespace(), "dapr-placement")
if err != nil {
log.Errorf("%s: %v", errEstablishingTLSConn, err)
return nil, errEstablishingTLSConn
}
opts = append(opts, sec.GRPCDialOptionMTLS(placementID))
if diag.DefaultGRPCMonitoring.IsEnabled() {
opts = append(
opts,
grpc.WithUnaryInterceptor(diag.DefaultGRPCMonitoring.UnaryClientInterceptor()),
grpc.WithBlock(),
)
}

View File

@ -26,10 +26,10 @@ import (
"google.golang.org/grpc/status"
"github.com/dapr/dapr/pkg/actors/internal"
daprCredentials "github.com/dapr/dapr/pkg/credentials"
diag "github.com/dapr/dapr/pkg/diagnostics"
"github.com/dapr/dapr/pkg/placement/hashing"
v1pb "github.com/dapr/dapr/pkg/proto/placement/v1"
"github.com/dapr/dapr/pkg/security"
"github.com/dapr/kit/logger"
)
@ -91,12 +91,14 @@ type actorPlacement struct {
shutdown atomic.Bool
// shutdownConnLoop is the wait group to wait until all connection loop are done
shutdownConnLoop sync.WaitGroup
// closeCh is the channel to close the placement service.
closeCh chan struct{}
}
// ActorPlacementOpts contains options for NewActorPlacement.
type ActorPlacementOpts struct {
ServerAddrs []string // Address(es) for the Placement service
CertChain *daprCredentials.CertChain
Security security.Handler
AppID string
RuntimeHostname string
PodName string
@ -115,7 +117,7 @@ func NewActorPlacement(opts ActorPlacementOpts) internal.PlacementService {
podName: opts.PodName,
serverAddr: servers,
client: newPlacementClient(getGrpcOptsGetter(servers, opts.CertChain)),
client: newPlacementClient(getGrpcOptsGetter(servers, opts.Security)),
placementTableLock: &sync.RWMutex{},
placementTables: &hashing.ConsistentHashTables{Entries: make(map[string]*hashing.Consistent)},
@ -124,6 +126,7 @@ func NewActorPlacement(opts ActorPlacementOpts) internal.PlacementService {
tableIsBlocked: &atomic.Bool{},
appHealthFn: opts.AppHealthFn,
afterTableUpdateFn: opts.AfterTableUpdateFn,
closeCh: make(chan struct{}),
}
}
@ -150,6 +153,18 @@ func (p *actorPlacement) Start(ctx context.Context) error {
return nil
}
ctx, cancel := context.WithCancel(ctx)
p.shutdownConnLoop.Add(1)
go func() {
defer p.shutdownConnLoop.Done()
select {
case <-ctx.Done():
case <-p.closeCh:
}
cancel()
}()
// establish connection loop, whenever a disconnect occurs it starts to run trying to connect to a new server.
p.shutdownConnLoop.Add(1)
go func() {
@ -252,6 +267,8 @@ func (p *actorPlacement) Close() error {
// CAS to avoid stop more than once.
if p.shutdown.CompareAndSwap(false, true) {
p.client.disconnect()
p.shutdown.Store(true)
close(p.closeCh)
}
p.shutdownConnLoop.Wait()
return nil

View File

@ -77,6 +77,7 @@ func TestPlacementStream_RoundRobin(t *testing.T) {
ActorTypes: []string{"actorOne", "actorTwo"},
AppHealthFn: appHealthFunc,
AfterTableUpdateFn: func() {},
Security: testSecurity(t),
}).(*actorPlacement)
t.Run("found leader placement in a round robin way", func(t *testing.T) {
@ -134,6 +135,7 @@ func TestAppHealthyStatus(t *testing.T) {
ActorTypes: []string{"actorOne", "actorTwo"},
AppHealthFn: appHealthFunc,
AfterTableUpdateFn: func() {},
Security: testSecurity(t),
}).(*actorPlacement)
// act
@ -166,6 +168,7 @@ func TestOnPlacementOrder(t *testing.T) {
ActorTypes: []string{"actorOne", "actorTwo"},
AppHealthFn: appHealthFunc,
AfterTableUpdateFn: tableUpdateFunc,
Security: testSecurity(t),
}).(*actorPlacement)
t.Run("lock operation", func(t *testing.T) {
@ -218,6 +221,7 @@ func TestWaitUntilPlacementTableIsReady(t *testing.T) {
ActorTypes: []string{"actorOne", "actorTwo"},
AppHealthFn: appHealthFunc,
AfterTableUpdateFn: func() {},
Security: testSecurity(t),
}).(*actorPlacement)
t.Run("already unlocked", func(t *testing.T) {
@ -295,6 +299,7 @@ func TestLookupActor(t *testing.T) {
ActorTypes: []string{"actorOne", "actorTwo"},
AppHealthFn: appHealthFunc,
AfterTableUpdateFn: func() {},
Security: testSecurity(t),
}).(*actorPlacement)
t.Run("Placementtable is unset", func(t *testing.T) {
@ -338,6 +343,7 @@ func TestConcurrentUnblockPlacements(t *testing.T) {
ActorTypes: []string{"actorOne", "actorTwo"},
AppHealthFn: appHealthFunc,
AfterTableUpdateFn: func() {},
Security: testSecurity(t),
}).(*actorPlacement)
t.Run("concurrent_unlock", func(t *testing.T) {

View File

@ -29,7 +29,7 @@ const (
// AppChannel is an abstraction over communications with user code.
type AppChannel interface {
GetAppConfig(appID string) (*config.ApplicationConfig, error)
GetAppConfig(ctx context.Context, appID string) (*config.ApplicationConfig, error)
InvokeMethod(ctx context.Context, req *invokev1.InvokeMethodRequest, appID string) (*invokev1.InvokeMethodResponse, error)
HealthProbe(ctx context.Context) (bool, error)
SetAppHealth(ah *apphealth.AppHealth)

View File

@ -66,7 +66,7 @@ func CreateLocalChannel(port, maxConcurrency int, conn *grpc.ClientConn, spec co
}
// GetAppConfig gets application config from user application.
func (g *Channel) GetAppConfig(appID string) (*config.ApplicationConfig, error) {
func (g *Channel) GetAppConfig(_ context.Context, appID string) (*config.ApplicationConfig, error) {
return nil, nil
}

View File

@ -103,7 +103,7 @@ func CreateHTTPChannel(config ChannelConfiguration) (channel.AppChannel, error)
// GetAppConfig gets application config from user application
// GET http://localhost:<app_port>/dapr/config
func (h *Channel) GetAppConfig(appID string) (*config.ApplicationConfig, error) {
func (h *Channel) GetAppConfig(ctx context.Context, appID string) (*config.ApplicationConfig, error) {
req := invokev1.NewInvokeMethodRequest(appConfigEndpoint).
WithHTTPExtension(http.MethodGet, "").
WithContentType(invokev1.JSONContentType).
@ -112,7 +112,7 @@ func (h *Channel) GetAppConfig(appID string) (*config.ApplicationConfig, error)
})
defer req.Close()
resp, err := h.InvokeMethod(context.TODO(), req, "")
resp, err := h.InvokeMethod(ctx, req, "")
if err != nil {
return nil, err
}

View File

@ -21,7 +21,7 @@ type MockAppChannel struct {
}
// GetAppConfig provides a mock function with given fields:
func (_m *MockAppChannel) GetAppConfig(_ string) (*config.ApplicationConfig, error) {
func (_m *MockAppChannel) GetAppConfig(_ context.Context, _ string) (*config.ApplicationConfig, error) {
ret := _m.Called()
var r0 *config.ApplicationConfig

View File

@ -1,42 +0,0 @@
package credentials
import (
"os"
)
var (
// RootCertFilename is the filename that holds the root certificate.
RootCertFilename = "ca.crt"
// IssuerCertFilename is the filename that holds the issuer certificate.
IssuerCertFilename = "issuer.crt"
// IssuerKeyFilename is the filename that holds the issuer key.
IssuerKeyFilename = "issuer.key"
)
// CertChain holds the certificate trust chain PEM values.
type CertChain struct {
RootCA []byte
Cert []byte
Key []byte
}
// LoadFromDisk retruns a CertChain from a given directory.
func LoadFromDisk(rootCertPath, issuerCertPath, issuerKeyPath string) (*CertChain, error) {
rootCert, err := os.ReadFile(rootCertPath)
if err != nil {
return nil, err
}
cert, err := os.ReadFile(issuerCertPath)
if err != nil {
return nil, err
}
key, err := os.ReadFile(issuerKeyPath)
if err != nil {
return nil, err
}
return &CertChain{
RootCA: rootCert,
Cert: cert,
Key: key,
}, nil
}

View File

@ -1,37 +0,0 @@
package credentials
import (
"path/filepath"
)
// TLSCredentials holds paths for credentials.
type TLSCredentials struct {
credentialsPath string
}
// NewTLSCredentials returns a new TLSCredentials.
func NewTLSCredentials(path string) TLSCredentials {
return TLSCredentials{
credentialsPath: path,
}
}
// Path returns the directory holding the TLS credentials.
func (t *TLSCredentials) Path() string {
return t.credentialsPath
}
// RootCertPath returns the file path for the root cert.
func (t *TLSCredentials) RootCertPath() string {
return filepath.Join(t.credentialsPath, RootCertFilename)
}
// CertPath returns the file path for the cert.
func (t *TLSCredentials) CertPath() string {
return filepath.Join(t.credentialsPath, IssuerCertFilename)
}
// KeyPath returns the file path for the cert key.
func (t *TLSCredentials) KeyPath() string {
return filepath.Join(t.credentialsPath, IssuerKeyFilename)
}

View File

@ -1,58 +0,0 @@
package credentials
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
func GetServerOptions(certChain *CertChain) ([]grpc.ServerOption, error) {
opts := []grpc.ServerOption{}
if certChain == nil {
return nil, nil
}
cp := x509.NewCertPool()
cp.AppendCertsFromPEM(certChain.RootCA)
cert, err := tls.X509KeyPair(certChain.Cert, certChain.Key)
if err != nil {
return nil, err
}
config := &tls.Config{
ClientCAs: cp,
// Require cert verification
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}
opts = append(opts, grpc.Creds(credentials.NewTLS(config)))
return opts, nil
}
func GetClientOptions(certChain *CertChain, serverName string) ([]grpc.DialOption, error) {
opts := []grpc.DialOption{}
if certChain != nil {
cp := x509.NewCertPool()
ok := cp.AppendCertsFromPEM(certChain.RootCA)
if !ok {
return nil, errors.New("failed to append PEM root cert to x509 CertPool")
}
config, err := TLSConfigFromCertAndKey(certChain.Cert, certChain.Key, serverName, cp)
config.MinVersion = tls.VersionTLS12
if err != nil {
return nil, fmt.Errorf("failed to create tls config from cert and key: %w", err)
}
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(config)))
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
return opts, nil
}

View File

@ -1,54 +0,0 @@
package credentials
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestServerOptions(t *testing.T) {
t.Run("valid certs", func(t *testing.T) {
chain := &CertChain{
RootCA: []byte(TestCACert),
Cert: []byte(TestCert),
Key: []byte(TestKey),
}
opts, err := GetServerOptions(chain)
assert.Nil(t, err)
assert.Len(t, opts, 1)
})
t.Run("invalid certs", func(t *testing.T) {
chain := &CertChain{
RootCA: []byte(nil),
Cert: []byte(nil),
Key: []byte(nil),
}
_, err := GetServerOptions(chain)
assert.Error(t, err)
})
}
func TestClientOptions(t *testing.T) {
t.Run("valid certs", func(t *testing.T) {
chain := &CertChain{
RootCA: []byte(TestCACert),
Cert: []byte(TestCert),
Key: []byte(TestKey),
}
opts, err := GetClientOptions(chain, "")
assert.Nil(t, err)
assert.Len(t, opts, 1)
})
t.Run("invalid certs", func(t *testing.T) {
chain := &CertChain{
RootCA: []byte(nil),
Cert: []byte(nil),
Key: []byte(nil),
}
opts, err := GetClientOptions(chain, "")
assert.NotNil(t, err)
assert.Nil(t, opts)
})
}

View File

@ -1,24 +0,0 @@
package credentials
import (
"crypto/tls"
"crypto/x509"
)
// TLSConfigFromCertAndKey return a tls.config object from valid cert/key pair in PEM format.
func TLSConfigFromCertAndKey(certPem, keyPem []byte, serverName string, rootCA *x509.CertPool) (*tls.Config, error) {
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
return nil, err
}
//nolint:gosec
config := &tls.Config{
InsecureSkipVerify: false,
RootCAs: rootCA,
ServerName: serverName,
Certificates: []tls.Certificate{cert},
}
return config, nil
}

View File

@ -1,56 +0,0 @@
package credentials
import (
"testing"
"github.com/stretchr/testify/assert"
)
var TestCACert = `-----BEGIN CERTIFICATE-----
MIIBjjCCATOgAwIBAgIQdZeGNuAHZhXSmb37Pnx2QzAKBggqhkjOPQQDAjAYMRYw
FAYDVQQDEw1jbHVzdGVyLmxvY2FsMB4XDTIwMDIwMTAwMzUzNFoXDTMwMDEyOTAw
MzUzNFowGDEWMBQGA1UEAxMNY2x1c3Rlci5sb2NhbDBZMBMGByqGSM49AgEGCCqG
SM49AwEHA0IABAeMFRst4JhcFpebfgEs1MvJdD7h5QkCbLwChRHVEUoaDqd1aYjm
bX5SuNBXz5TBEhHfTV3Objh6LQ2N+CBoCeOjXzBdMA4GA1UdDwEB/wQEAwIBBjAS
BgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBRBWthv5ZQ3vALl2zXWwAXSmZ+m
qTAYBgNVHREEETAPgg1jbHVzdGVyLmxvY2FsMAoGCCqGSM49BAMCA0kAMEYCIQDN
rQNOck4ENOhmLROE/wqH0MKGjE6P8yzesgnp9fQI3AIhAJaVPrZloxl1dWCgmNWo
Iklq0JnMgJU7nS+VpVvlgBN8
-----END CERTIFICATE-----`
var TestCert = `-----BEGIN CERTIFICATE-----
MIIBXDCCAQOgAwIBAgIRALFHPINM7m/sHbH775ZjtGYwCgYIKoZIzj0EAwIwKjEX
MBUGA1UEChMOZGFwci5pby9zZW50cnkxDzANBgNVBAMTBnNlbnRyeTAeFw0yMDAy
MTEwMDQ1NThaFw0yMTAyMTAwMDQ1NThaMBExDzANBgNVBAMTBnNlbnRyeTBZMBMG
ByqGSM49AgEGCCqGSM49AwEHA0IABK4QF+h1jJDBnXcWc4lwewgq+4fcb7Ud6SSx
FEiiaOTSsZfb/IY0T8VGLHSalc1jFlCfD8mNuhjx9QTgR6QPRwGjIzAhMA4GA1Ud
DwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA0cAMEQCIBk1
k8Cu51NLvo2esE4YvA65fzjYIo7hC7JjQJ107QARAiAnbsZu/InV17eJWTohNSPB
hIzOUyB1HWO0KobCoOPGPQ==
-----END CERTIFICATE-----`
var TestKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIqwzdYX+5OM7qeU3sWCApUdyK35q11i3ma1JmcRHxcJoAoGCCqGSM49
AwEHoUQDQgAErhAX6HWMkMGddxZziXB7CCr7h9xvtR3pJLEUSKJo5NKxl9v8hjRP
xUYsdJqVzWMWUJ8PyY26GPH1BOBHpA9HAQ==
-----END EC PRIVATE KEY-----`
func TestTLSConfigFromCertAndKey(t *testing.T) {
t.Run("invalid cert", func(t *testing.T) {
conf, err := TLSConfigFromCertAndKey(nil, []byte(TestKey), "server", nil)
assert.NotNil(t, err)
assert.Nil(t, conf)
})
t.Run("invalid key", func(t *testing.T) {
conf, err := TLSConfigFromCertAndKey([]byte(TestCert), nil, "server", nil)
assert.NotNil(t, err)
assert.Nil(t, conf)
})
t.Run("valid cert and keys", func(t *testing.T) {
conf, err := TLSConfigFromCertAndKey([]byte(TestCert), []byte(TestKey), "server", nil)
assert.Nil(t, err)
assert.NotNil(t, conf)
})
}

View File

@ -10,6 +10,7 @@ import (
"go.opencensus.io/tag"
diagUtils "github.com/dapr/dapr/pkg/diagnostics/utils"
"github.com/dapr/dapr/pkg/security/spiffe"
)
// Tag keys.
@ -415,15 +416,15 @@ func (s *serviceMetrics) ReportActorPendingCalls(actorType string, pendingLocks
}
// RequestAllowedByAppAction records the requests allowed due to a match with the action specified in the access control policy for the app.
func (s *serviceMetrics) RequestAllowedByAppAction(appID, trustDomain, namespace, operation, httpverb string, policyAction bool) {
func (s *serviceMetrics) RequestAllowedByAppAction(spiffeID *spiffe.Parsed, operation, httpverb string, policyAction bool) {
if s.enabled {
stats.RecordWithTags(
s.ctx,
diagUtils.WithTags(
s.appPolicyActionAllowed.Name(),
appIDKey, appID,
trustDomainKey, trustDomain,
namespaceKey, namespace,
appIDKey, spiffeID.AppID(),
trustDomainKey, spiffeID.TrustDomain().String(),
namespaceKey, spiffeID.Namespace(),
operationKey, operation,
httpMethodKey, httpverb,
policyActionKey, policyAction),
@ -432,15 +433,15 @@ func (s *serviceMetrics) RequestAllowedByAppAction(appID, trustDomain, namespace
}
// RequestBlockedByAppAction records the requests blocked due to a match with the action specified in the access control policy for the app.
func (s *serviceMetrics) RequestBlockedByAppAction(appID, trustDomain, namespace, operation, httpverb string, policyAction bool) {
func (s *serviceMetrics) RequestBlockedByAppAction(spiffeID *spiffe.Parsed, operation, httpverb string, policyAction bool) {
if s.enabled {
stats.RecordWithTags(
s.ctx,
diagUtils.WithTags(
s.appPolicyActionBlocked.Name(),
appIDKey, appID,
trustDomainKey, trustDomain,
namespaceKey, namespace,
appIDKey, spiffeID.AppID(),
trustDomainKey, spiffeID.TrustDomain().String(),
namespaceKey, spiffeID.Namespace(),
operationKey, operation,
httpMethodKey, httpverb,
policyActionKey, policyAction),
@ -449,15 +450,15 @@ func (s *serviceMetrics) RequestBlockedByAppAction(appID, trustDomain, namespace
}
// RequestAllowedByGlobalAction records the requests allowed due to a match with the global action in the access control policy.
func (s *serviceMetrics) RequestAllowedByGlobalAction(appID, trustDomain, namespace, operation, httpverb string, policyAction bool) {
func (s *serviceMetrics) RequestAllowedByGlobalAction(spiffeID *spiffe.Parsed, operation, httpverb string, policyAction bool) {
if s.enabled {
stats.RecordWithTags(
s.ctx,
diagUtils.WithTags(
s.globalPolicyActionAllowed.Name(),
appIDKey, appID,
trustDomainKey, trustDomain,
namespaceKey, namespace,
appIDKey, spiffeID.AppID(),
trustDomainKey, spiffeID.TrustDomain().String(),
namespaceKey, spiffeID.Namespace(),
operationKey, operation,
httpMethodKey, httpverb,
policyActionKey, policyAction),
@ -466,15 +467,15 @@ func (s *serviceMetrics) RequestAllowedByGlobalAction(appID, trustDomain, namesp
}
// RequestBlockedByGlobalAction records the requests blocked due to a match with the global action in the access control policy.
func (s *serviceMetrics) RequestBlockedByGlobalAction(appID, trustDomain, namespace, operation, httpverb string, policyAction bool) {
func (s *serviceMetrics) RequestBlockedByGlobalAction(spiffeID *spiffe.Parsed, operation, httpverb string, policyAction bool) {
if s.enabled {
stats.RecordWithTags(
s.ctx,
diagUtils.WithTags(
s.globalPolicyActionBlocked.Name(),
appIDKey, appID,
trustDomainKey, trustDomain,
namespaceKey, namespace,
appIDKey, spiffeID.AppID(),
trustDomainKey, spiffeID.TrustDomain().String(),
namespaceKey, spiffeID.Namespace(),
operationKey, operation,
httpMethodKey, httpverb,
policyActionKey, policyAction),

View File

@ -32,7 +32,7 @@ import (
"github.com/dapr/dapr/pkg/config"
diag "github.com/dapr/dapr/pkg/diagnostics"
"github.com/dapr/dapr/pkg/modes"
"github.com/dapr/dapr/pkg/runtime/security"
"github.com/dapr/dapr/pkg/security"
)
const (
@ -60,33 +60,29 @@ type AppChannelConfig struct {
// Manager is a wrapper around gRPC connection pooling.
type Manager struct {
remoteConns *RemoteConnectionPool
auth security.Authenticator
mode modes.DaprMode
channelConfig *AppChannelConfig
localConn *ConnectionPool
localConnLock sync.RWMutex
appClientConn grpc.ClientConnInterface
sec security.Handler
wg sync.WaitGroup
closed atomic.Bool
closeCh chan struct{}
}
// NewManager returns a new grpc manager.
func NewManager(mode modes.DaprMode, channelConfig *AppChannelConfig) *Manager {
func NewManager(sec security.Handler, mode modes.DaprMode, channelConfig *AppChannelConfig) *Manager {
return &Manager{
remoteConns: NewRemoteConnectionPool(),
mode: mode,
channelConfig: channelConfig,
localConn: NewConnectionPool(maxConnIdle, 1),
sec: sec,
closeCh: make(chan struct{}),
}
}
// SetAuthenticator sets the gRPC manager a tls authenticator context.
func (g *Manager) SetAuthenticator(auth security.Authenticator) {
g.auth = auth
}
// GetAppChannel returns a connection to the local channel.
// If there's no active connection to the app, it creates one.
func (g *Manager) GetAppChannel() (channel.AppChannel, error) {
@ -202,29 +198,7 @@ func (g *Manager) connectRemote(
)
}
if g.auth != nil {
signedCert := g.auth.GetCurrentSignedCert()
var cert tls.Certificate
cert, err = tls.X509KeyPair(signedCert.WorkloadCert, signedCert.PrivateKeyPem)
if err != nil {
return nil, fmt.Errorf("error loading x509 Key Pair: %w", err)
}
var serverName string
if id != "cluster.local" {
serverName = id + "." + namespace + ".svc.cluster.local"
}
//nolint:gosec
ta := credentials.NewTLS(&tls.Config{
ServerName: serverName,
Certificates: []tls.Certificate{cert},
RootCAs: signedCert.TrustChain,
})
opts = append(opts, grpc.WithTransportCredentials(ta))
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
opts = append(opts, g.sec.GRPCDialOptionMTLSUnknownTrustDomain(namespace, id))
opts = append(opts, customOpts...)

View File

@ -14,47 +14,23 @@ limitations under the License.
package manager
import (
"crypto/x509"
"testing"
"github.com/stretchr/testify/assert"
"github.com/dapr/dapr/pkg/modes"
"github.com/dapr/dapr/pkg/runtime/security"
)
type authenticatorMock struct{}
func (a *authenticatorMock) GetTrustAnchors() *x509.CertPool {
return nil
}
func (a *authenticatorMock) GetCurrentSignedCert() *security.SignedCertificate {
return nil
}
func (a *authenticatorMock) CreateSignedWorkloadCert(id, namespace, trustDomain string) (*security.SignedCertificate, error) {
return nil, nil
}
func TestNewManager(t *testing.T) {
t.Run("with self hosted", func(t *testing.T) {
m := NewManager(modes.StandaloneMode, &AppChannelConfig{})
m := NewManager(nil, modes.StandaloneMode, &AppChannelConfig{})
assert.NotNil(t, m)
assert.Equal(t, modes.StandaloneMode, m.mode)
})
t.Run("with kubernetes", func(t *testing.T) {
m := NewManager(modes.KubernetesMode, &AppChannelConfig{})
m := NewManager(nil, modes.KubernetesMode, &AppChannelConfig{})
assert.NotNil(t, m)
assert.Equal(t, modes.KubernetesMode, m.mode)
})
}
func TestSetAuthenticator(t *testing.T) {
a := &authenticatorMock{}
m := NewManager(modes.StandaloneMode, &AppChannelConfig{})
m.SetAuthenticator(a)
assert.Equal(t, a, m.auth)
}

View File

@ -15,7 +15,6 @@ package grpc
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
@ -27,7 +26,7 @@ import (
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/dapr/dapr/pkg/config"
@ -37,7 +36,6 @@ import (
"github.com/dapr/dapr/pkg/messaging"
internalv1pb "github.com/dapr/dapr/pkg/proto/internals/v1"
runtimev1pb "github.com/dapr/dapr/pkg/proto/runtime/v1"
rtSecurity "github.com/dapr/dapr/pkg/runtime/security"
"github.com/dapr/dapr/pkg/runtime/wfengine"
"github.com/dapr/dapr/pkg/security"
securityConsts "github.com/dapr/dapr/pkg/security/consts"
@ -59,27 +57,23 @@ type Server interface {
}
type server struct {
api API
config ServerConfig
tracingSpec config.TracingSpec
metricSpec config.MetricSpec
authenticator rtSecurity.Authenticator
servers []*grpc.Server
renewMutex sync.Mutex
signedCert *rtSecurity.SignedCertificate
tlsCert tls.Certificate
signedCertDuration time.Duration
kind string
logger logger.Logger
infoLogger logger.Logger
maxConnectionAge *time.Duration
authToken string
apiSpec config.APISpec
proxy messaging.Proxy
workflowEngine *wfengine.WorkflowEngine
wg sync.WaitGroup
closed atomic.Bool
closeCh chan struct{}
api API
config ServerConfig
tracingSpec config.TracingSpec
metricSpec config.MetricSpec
servers []*grpc.Server
kind string
logger logger.Logger
infoLogger logger.Logger
maxConnectionAge *time.Duration
authToken string
apiSpec config.APISpec
proxy messaging.Proxy
workflowEngine *wfengine.WorkflowEngine
sec security.Handler
wg sync.WaitGroup
closed atomic.Bool
closeCh chan struct{}
}
var (
@ -108,17 +102,17 @@ func NewAPIServer(api API, config ServerConfig, tracingSpec config.TracingSpec,
}
// NewInternalServer returns a new gRPC server for Dapr to Dapr communications.
func NewInternalServer(api API, config ServerConfig, tracingSpec config.TracingSpec, metricSpec config.MetricSpec, authenticator rtSecurity.Authenticator, proxy messaging.Proxy) Server {
func NewInternalServer(api API, config ServerConfig, tracingSpec config.TracingSpec, metricSpec config.MetricSpec, sec security.Handler, proxy messaging.Proxy) Server {
return &server{
api: api,
config: config,
tracingSpec: tracingSpec,
metricSpec: metricSpec,
authenticator: authenticator,
kind: internalServer,
logger: internalServerLogger,
maxConnectionAge: getDefaultMaxAgeDuration(),
proxy: proxy,
sec: sec,
closeCh: make(chan struct{}),
}
}
@ -212,25 +206,6 @@ func (s *server) Close() error {
return nil
}
func (s *server) generateWorkloadCert() error {
s.logger.Info("sending workload csr request to sentry")
signedCert, err := s.authenticator.CreateSignedWorkloadCert(s.config.AppID, s.config.NameSpace, s.config.TrustDomain)
if err != nil {
return fmt.Errorf("error from authenticator CreateSignedWorkloadCert: %w", err)
}
s.logger.Info("certificate signed successfully")
tlsCert, err := tls.X509KeyPair(signedCert.WorkloadCert, signedCert.PrivateKeyPem)
if err != nil {
return fmt.Errorf("error creating x509 Key Pair: %w", err)
}
s.signedCert = signedCert
s.tlsCert = tlsCert
s.signedCertDuration = signedCert.Expiry.Sub(time.Now().UTC())
return nil
}
func (s *server) getMiddlewareOptions() []grpc.ServerOption {
intr := make([]grpc.UnaryServerInterceptor, 0, 6)
intrStream := make([]grpc.StreamServerInterceptor, 0, 5)
@ -289,42 +264,18 @@ func (s *server) getGRPCServer() (*grpc.Server, error) {
opts = append(opts, grpc.KeepaliveParams(keepalive.ServerParameters{MaxConnectionAge: *s.maxConnectionAge}))
}
if s.authenticator != nil {
err := s.generateWorkloadCert()
if err != nil {
return nil, err
}
//nolint:gosec
tlsConfig := tls.Config{
ClientCAs: s.signedCert.TrustChain,
ClientAuth: tls.RequireAndVerifyClientCert,
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return &s.tlsCert, nil
},
}
// In the internal server, enforce minimum version TLS 1.2
if s.kind == internalServer {
tlsConfig.MinVersion = tls.VersionTLS12
}
ta := credentials.NewTLS(&tlsConfig)
opts = append(opts, grpc.Creds(ta))
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.startWorkloadCertRotation()
}()
}
opts = append(opts,
grpc.MaxRecvMsgSize(s.config.MaxRequestBodySizeMB<<20),
grpc.MaxSendMsgSize(s.config.MaxRequestBodySizeMB<<20),
grpc.MaxHeaderListSize(uint32(s.config.ReadBufferSizeKB<<10)),
)
if s.sec == nil {
opts = append(opts, grpc.Creds(insecure.NewCredentials()))
} else {
opts = append(opts, s.sec.GRPCServerOptionMTLS())
}
if s.proxy != nil {
opts = append(opts, grpc.UnknownServiceHandler(s.proxy.Handler()))
}
@ -332,44 +283,6 @@ func (s *server) getGRPCServer() (*grpc.Server, error) {
return grpc.NewServer(opts...), nil
}
func (s *server) startWorkloadCertRotation() {
s.logger.Infof("starting workload cert expiry watcher. current cert expires on: %s", s.signedCert.Expiry.String())
ticker := time.NewTicker(certWatchInterval)
defer ticker.Stop()
for {
select {
case <-s.closeCh:
return
case <-ticker.C:
s.renewMutex.Lock()
renew := shouldRenewCert(s.signedCert.Expiry, s.signedCertDuration)
if renew {
s.logger.Info("renewing certificate: requesting new cert and restarting gRPC server")
err := s.generateWorkloadCert()
if err != nil {
s.logger.Errorf("error starting server: %s", err)
s.renewMutex.Unlock()
continue
}
diag.DefaultMonitoring.MTLSWorkLoadCertRotationCompleted()
}
s.renewMutex.Unlock()
}
}
}
func shouldRenewCert(certExpiryDate time.Time, certDuration time.Duration) bool {
expiresIn := certExpiryDate.Sub(time.Now())
expiresInSeconds := expiresIn.Seconds()
certDurationSeconds := certDuration.Seconds()
percentagePassed := 100 - ((expiresInSeconds / certDurationSeconds) * 100)
return percentagePassed >= renewWhenPercentagePassed
}
func (s *server) getGRPCAPILoggingMiddlewares() (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor) {
if s.infoLogger == nil {
return nil, nil

View File

@ -25,25 +25,6 @@ import (
"github.com/dapr/kit/logger"
)
func TestCertRenewal(t *testing.T) {
t.Run("shouldn't renew", func(t *testing.T) {
certExpiry := time.Now().Add(time.Hour * 2).UTC()
certDuration := certExpiry.Sub(time.Now().UTC())
renew := shouldRenewCert(certExpiry, certDuration)
assert.False(t, renew)
})
t.Run("should renew", func(t *testing.T) {
certExpiry := time.Now().Add(time.Second * 3).UTC()
certDuration := certExpiry.Sub(time.Now().UTC())
time.Sleep(time.Millisecond * 2200)
renew := shouldRenewCert(certExpiry, certDuration)
assert.True(t, renew)
})
}
func TestGetMiddlewareOptions(t *testing.T) {
t.Run("should enable unary interceptor if tracing and metrics are enabled", func(t *testing.T) {
fakeServer := &server{

View File

@ -36,7 +36,6 @@ type SidecarConfig struct {
Mode injectorConsts.DaprMode `default:"kubernetes"`
Namespace string
TrustAnchors string
CertChain string
CertKey string
MTLSEnabled bool
@ -49,6 +48,9 @@ type SidecarConfig struct {
ReadOnlyRootFilesystem bool
SidecarDropALLCapabilities bool
DisableTokenVolume bool
CurrentTrustAnchors []byte
ControlPlaneNamespace string
ControlPlaneTrustDomain string
SidecarHTTPPort int32 `default:"3500"`
SidecarAPIGRPCPort int32 `default:"50001"`
SidecarInternalGRPCPort int32 `default:"50002"`

View File

@ -229,6 +229,19 @@ func (c *SidecarConfig) getSidecarContainer(opts getSidecarContainerOpts) (*core
},
},
},
{
Name: securityConsts.TrustAnchorsEnvVar,
Value: string(c.CurrentTrustAnchors),
},
// TODO: @joshvanl: In v1.14, this two env vars should be moved to flags.
{
Name: securityConsts.ControlPlaneNamespaceEnvVar,
Value: c.ControlPlaneNamespace,
},
{
Name: securityConsts.ControlPlaneTrustDomainEnvVar,
Value: c.ControlPlaneTrustDomain,
},
},
VolumeMounts: opts.VolumeMounts,
ReadinessProbe: &corev1.Probe{
@ -247,6 +260,24 @@ func (c *SidecarConfig) getSidecarContainer(opts getSidecarContainerOpts) (*core
},
}
// TODO: @joshvanl: included for backwards compatibility with v1.11 daprd's
// which request these environment variables to be present when running in
// Kubernetes. Should be removed in v1.13.
container.Env = append(container.Env,
corev1.EnvVar{
Name: securityConsts.CertChainEnvVar,
Value: c.CertChain,
},
corev1.EnvVar{
Name: securityConsts.CertKeyEnvVar,
Value: c.CertKey,
},
corev1.EnvVar{
Name: "SENTRY_LOCAL_IDENTITY",
Value: c.Identity,
},
)
// If the pod contains any of the tolerations specified by the configuration,
// the Command and Args are passed as is. Otherwise, the Command is passed as a part of Args.
// This is to allow the Docker images to specify an ENTRYPOINT
@ -296,25 +327,6 @@ func (c *SidecarConfig) getSidecarContainer(opts getSidecarContainerOpts) (*core
})
}
container.Env = append(container.Env,
corev1.EnvVar{
Name: securityConsts.TrustAnchorsEnvVar,
Value: c.TrustAnchors,
},
corev1.EnvVar{
Name: securityConsts.CertChainEnvVar,
Value: c.CertChain,
},
corev1.EnvVar{
Name: securityConsts.CertKeyEnvVar,
Value: c.CertKey,
},
corev1.EnvVar{
Name: "SENTRY_LOCAL_IDENTITY",
Value: c.Identity,
},
)
if c.APITokenSecret != "" {
container.Env = append(container.Env, corev1.EnvVar{
Name: securityConsts.APITokenEnvVar,

View File

@ -326,6 +326,10 @@ func TestGetSidecarContainer(t *testing.T) {
c.SentryAddress = "sentry:50000"
c.MTLSEnabled = true
c.Identity = "pod_identity"
c.ControlPlaneNamespace = "my-namespace"
c.ControlPlaneTrustDomain = "test.example.com"
c.CertChain = "my-cert-chain"
c.CertKey = "my-cert-key"
c.SetFromPodAnnotations()
@ -361,10 +365,20 @@ func TestGetSidecarContainer(t *testing.T) {
assert.Equal(t, "dapr-system", container.Env[0].Value)
// POD_NAME
assert.Equal(t, "metadata.name", container.Env[1].ValueFrom.FieldRef.FieldPath)
// DAPR_CONTROLPLANE_NAMESPACE
assert.Equal(t, "my-namespace", container.Env[3].Value)
// DAPR_CONTROLPLANE_TRUST_DOMAIN
assert.Equal(t, "test.example.com", container.Env[4].Value)
// DAPR_CERT_CHAIN
assert.Equal(t, "my-cert-chain", container.Env[5].Value)
// DAPR_CERT_KEY
assert.Equal(t, "my-cert-key", container.Env[6].Value)
// SENTRY_LOCAL_IDENTITY
assert.Equal(t, "pod_identity", container.Env[7].Value)
// DAPR_API_TOKEN
assert.Equal(t, "secret", container.Env[6].ValueFrom.SecretKeyRef.Name)
assert.Equal(t, "secret", container.Env[8].ValueFrom.SecretKeyRef.Name)
// DAPR_APP_TOKEN
assert.Equal(t, "appsecret", container.Env[7].ValueFrom.SecretKeyRef.Name)
assert.Equal(t, "appsecret", container.Env[9].ValueFrom.SecretKeyRef.Name)
// default image
assert.Equal(t, "daprio/dapr", container.Image)
assert.EqualValues(t, expectedArgs, container.Args)
@ -394,6 +408,10 @@ func TestGetSidecarContainer(t *testing.T) {
c.SentryAddress = "sentry:50000"
c.MTLSEnabled = true
c.Identity = "pod_identity"
c.ControlPlaneNamespace = "my-namespace"
c.ControlPlaneTrustDomain = "test.example.com"
c.CertChain = "my-cert-chain"
c.CertKey = "my-cert-key"
c.SetFromPodAnnotations()
@ -437,10 +455,20 @@ func TestGetSidecarContainer(t *testing.T) {
assert.Equal(t, "dapr-system", container.Env[0].Value)
// POD_NAME
assert.Equal(t, "metadata.name", container.Env[1].ValueFrom.FieldRef.FieldPath)
// DAPR_CONTROLPLANE_NAMESPACE
assert.Equal(t, "my-namespace", container.Env[3].Value)
// DAPR_CONTROLPLANE_TRUST_DOMAIN
assert.Equal(t, "test.example.com", container.Env[4].Value)
// DAPR_CERT_CHAIN
assert.Equal(t, "my-cert-chain", container.Env[5].Value)
// DAPR_CERT_KEY
assert.Equal(t, "my-cert-key", container.Env[6].Value)
// SENTRY_LOCAL_IDENTITY
assert.Equal(t, "pod_identity", container.Env[7].Value)
// DAPR_API_TOKEN
assert.Equal(t, "secret", container.Env[6].ValueFrom.SecretKeyRef.Name)
assert.Equal(t, "secret", container.Env[8].ValueFrom.SecretKeyRef.Name)
// DAPR_APP_TOKEN
assert.Equal(t, "appsecret", container.Env[7].ValueFrom.SecretKeyRef.Name)
assert.Equal(t, "appsecret", container.Env[9].ValueFrom.SecretKeyRef.Name)
// default image
assert.Equal(t, "daprio/dapr", container.Image)
assert.EqualValues(t, expectedArgs, container.Args)

View File

@ -267,6 +267,8 @@ func TestPatching(t *testing.T) {
c := NewSidecarConfig(pod)
c.Namespace = "testns"
c.Identity = "pod:identity"
c.CertChain = "certchain"
c.CertKey = "certkey"
if tc.sidecarConfigModifierFn != nil {
tc.sidecarConfigModifierFn(c)

View File

@ -0,0 +1,114 @@
/*
Copyright 2023 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sentry
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"os"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"google.golang.org/grpc"
sentryv1pb "github.com/dapr/dapr/pkg/proto/sentry/v1"
"github.com/dapr/dapr/pkg/security"
securitytoken "github.com/dapr/dapr/pkg/security/token"
)
// Options contains the configuration options for connecting and requesting
// certificates from sentry.
type Options struct {
SentryAddress string
SentryID spiffeid.ID
Security security.Handler
}
// Requester is used to request certificates from the sentry service for any
// daprd identity.
type Requester struct {
sentryAddress string
sentryID spiffeid.ID
sec security.Handler
kubernetesMode bool
}
// New returns a new instance of the Requester.
func New(opts Options) *Requester {
_, kubeMode := os.LookupEnv("KUBERNETES_SERVICE_HOST")
return &Requester{
sentryAddress: opts.SentryAddress,
sentryID: opts.SentryID,
sec: opts.Security,
kubernetesMode: kubeMode,
}
}
// RequestCertificateFromSentry requests a certificate from sentry for a
// generic daprd identity in a namespace.
// Returns the signed certificate chain and leaf private key as a PEM encoded
// byte slice.
func (r *Requester) RequestCertificateFromSentry(ctx context.Context, namespace string) ([]byte, []byte, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}
csrDER, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{
Subject: pkix.Name{CommonName: "_unknown"},
DNSNames: []string{"_unknown"},
}, key)
if err != nil {
return nil, nil, fmt.Errorf("failed to create sidecar csr: %w", err)
}
conn, err := grpc.DialContext(ctx, r.sentryAddress, r.sec.GRPCDialOptionMTLS(r.sentryID))
if err != nil {
return nil, nil, fmt.Errorf("error establishing connection to sentry: %w", err)
}
defer conn.Close()
token, tokenValidator, err := securitytoken.GetSentryToken(r.kubernetesMode)
if err != nil {
return nil, nil, fmt.Errorf("error obtaining token: %w", err)
}
resp, err := sentryv1pb.NewCAClient(conn).SignCertificate(ctx,
&sentryv1pb.SignCertificateRequest{
CertificateSigningRequest: pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST", Bytes: csrDER,
}),
Id: "_unknown",
Token: token,
Namespace: namespace,
TokenValidator: tokenValidator,
})
if err != nil {
return nil, nil, fmt.Errorf("error from sentry SignCertificate: %w", err)
}
keyCS8, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
return resp.WorkloadCertificate, pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY", Bytes: keyCS8,
}), nil
}

View File

@ -24,8 +24,6 @@ import (
// Config represents configuration options for the Dapr Sidecar Injector webhook server.
type Config struct {
TLSCertFile string `envconfig:"TLS_CERT_FILE" required:"true"`
TLSKeyFile string `envconfig:"TLS_KEY_FILE" required:"true"`
SidecarImage string `envconfig:"SIDECAR_IMAGE" required:"true"`
SidecarImagePullPolicy string `envconfig:"SIDECAR_IMAGE_PULL_POLICY"`
Namespace string `envconfig:"NAMESPACE" required:"true"`
@ -38,6 +36,10 @@ type Config struct {
ReadOnlyRootFilesystem string `envconfig:"SIDECAR_READ_ONLY_ROOT_FILESYSTEM"`
SidecarDropALLCapabilities string `envconfig:"SIDECAR_DROP_ALL_CAPABILITIES"`
TrustAnchorsFile string `envconfig:"DAPR_TRUST_ANCHORS_FILE"`
ControlPlaneTrustDomain string `envconfig:"DAPR_CONTROL_PLANE_TRUST_DOMAIN"`
SentryAddress string `envconfig:"DAPR_SENTRY_ADDRESS"`
parsedEntrypointTolerations []corev1.Toleration
}
@ -46,7 +48,9 @@ type Config struct {
// and/or override default values.
func NewConfigWithDefaults() Config {
return Config{
SidecarImagePullPolicy: "Always",
SidecarImagePullPolicy: "Always",
ControlPlaneTrustDomain: "cluster.local",
TrustAnchorsFile: "/var/run/dapr.io/tls/ca.crt",
}
}

View File

@ -33,8 +33,6 @@ func TestGetInjectorConfig(t *testing.T) {
cfg, err := GetConfig()
assert.NoError(t, err)
assert.Equal(t, "test-cert-file", cfg.TLSCertFile)
assert.Equal(t, "test-key-file", cfg.TLSKeyFile)
assert.Equal(t, "daprd-test-image", cfg.SidecarImage)
assert.Equal(t, "Always", cfg.SidecarImagePullPolicy)
assert.Equal(t, "test-namespace", cfg.Namespace)
@ -53,8 +51,6 @@ func TestGetInjectorConfig(t *testing.T) {
cfg, err := GetConfig()
assert.NoError(t, err)
assert.Equal(t, "test-cert-file", cfg.TLSCertFile)
assert.Equal(t, "test-key-file", cfg.TLSKeyFile)
assert.Equal(t, "daprd-test-image", cfg.SidecarImage)
assert.Equal(t, "IfNotPresent", cfg.SidecarImagePullPolicy)
assert.Equal(t, "test-namespace", cfg.Namespace)

View File

@ -15,6 +15,7 @@ package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -37,16 +38,26 @@ import (
func TestHandleRequest(t *testing.T) {
authID := "test-auth-id"
i, err := NewInjector([]string{authID}, Config{
TLSCertFile: "test-cert",
TLSKeyFile: "test-key",
SidecarImage: "test-image",
Namespace: "test-ns",
AllowedServiceAccountsPrefixNames: "vc-proj*:sa-dev*,vc-all-allowed*:*",
}, fake.NewSimpleClientset(), kubernetesfake.NewSimpleClientset())
i, err := NewInjector(Options{
AuthUIDs: []string{authID},
Config: Config{
SidecarImage: "test-image",
Namespace: "test-ns",
ControlPlaneTrustDomain: "test-trust-domain",
AllowedServiceAccountsPrefixNames: "vc-proj*:sa-dev*,vc-all-allowed*:*",
},
DaprClient: fake.NewSimpleClientset(),
KubeClient: kubernetesfake.NewSimpleClientset(),
})
assert.NoError(t, err)
injector := i.(*injector)
injector.currentTrustAnchors = func() ([]byte, error) {
return nil, nil
}
injector.signDaprdCertificate = func(context.Context, string) ([]byte, []byte, error) {
return []byte("test-cert"), []byte("test-key"), nil
}
podBytes, _ := json.Marshal(corev1.Pod{
TypeMeta: metav1.TypeMeta{

View File

@ -19,7 +19,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"strings"
"time"
@ -56,12 +55,27 @@ var AllowedServiceAccountInfos = []string{
"tekton-pipelines:tekton-pipelines-controller",
}
type (
signDaprdCertificateFn func(ctx context.Context, namespace string) (cert []byte, key []byte, err error)
currentTrustAnchorsFn func() (ca []byte, err error)
)
// Injector is the interface for the Dapr runtime sidecar injection component.
type Injector interface {
Run(context.Context) error
Run(context.Context, *tls.Config, signDaprdCertificateFn, currentTrustAnchorsFn) error
Ready(context.Context) error
}
type Options struct {
AuthUIDs []string
Config Config
DaprClient scheme.Interface
KubeClient kubernetes.Interface
ControlPlaneNamespace string
ControlPlaneTrustDomain string
}
type injector struct {
config Config
deserializer runtime.Decoder
@ -70,6 +84,11 @@ type injector struct {
daprClient scheme.Interface
authUIDs []string
controlPlaneNamespace string
controlPlaneTrustDomain string
currentTrustAnchors currentTrustAnchorsFn
signDaprdCertificate signDaprdCertificateFn
namespaceNameMatcher *namespacednamematcher.EqualPrefixNameNamespaceMatcher
ready chan struct{}
}
@ -110,29 +129,28 @@ func getAppIDFromRequest(req *admissionv1.AdmissionRequest) (appID string) {
}
// NewInjector returns a new Injector instance with the given config.
func NewInjector(authUIDs []string, config Config, daprClient scheme.Interface, kubeClient kubernetes.Interface) (Injector, error) {
func NewInjector(opts Options) (Injector, error) {
mux := http.NewServeMux()
i := &injector{
config: config,
config: opts.Config,
deserializer: serializer.NewCodecFactory(
runtime.NewScheme(),
).UniversalDeserializer(),
server: &http.Server{
Addr: fmt.Sprintf(":%d", port),
Handler: mux,
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
Addr: fmt.Sprintf(":%d", port),
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
},
kubeClient: kubeClient,
daprClient: daprClient,
authUIDs: authUIDs,
ready: make(chan struct{}),
kubeClient: opts.KubeClient,
daprClient: opts.DaprClient,
authUIDs: opts.AuthUIDs,
controlPlaneNamespace: opts.ControlPlaneNamespace,
controlPlaneTrustDomain: opts.ControlPlaneTrustDomain,
ready: make(chan struct{}),
}
matcher, err := createNamespaceNameMatcher(config.AllowedServiceAccountsPrefixNames)
matcher, err := createNamespaceNameMatcher(opts.Config.AllowedServiceAccountsPrefixNames)
if err != nil {
return nil, err
}
@ -195,7 +213,7 @@ func getServiceAccount(ctx context.Context, kubeClient kubernetes.Interface, all
return allowedUids, nil
}
func (i *injector) Run(ctx context.Context) error {
func (i *injector) Run(ctx context.Context, tlsConfig *tls.Config, signDaprdFn signDaprdCertificateFn, currentTrustAnchors currentTrustAnchorsFn) error {
select {
case <-i.ready:
return errors.New("injector already running")
@ -203,18 +221,17 @@ func (i *injector) Run(ctx context.Context) error {
// Nop
}
ln, err := net.Listen("tcp", i.server.Addr)
if err != nil {
return fmt.Errorf("error while creating listener: %w", err)
}
log.Infof("Sidecar injector is listening on %s, patching Dapr-enabled pods", i.server.Addr)
i.currentTrustAnchors = currentTrustAnchors
i.signDaprdCertificate = signDaprdFn
i.server.TLSConfig = tlsConfig
errCh := make(chan error, 1)
go func() {
srverr := i.server.ServeTLS(ln, i.config.TLSCertFile, i.config.TLSKeyFile)
if !errors.Is(srverr, http.ErrServerClosed) {
errCh <- fmt.Errorf("sidecar injector error: %s", srverr)
err := i.server.ListenAndServeTLS("", "")
if !errors.Is(err, http.ErrServerClosed) {
errCh <- fmt.Errorf("sidecar injector error: %w", err)
return
}
errCh <- nil
@ -227,11 +244,11 @@ func (i *injector) Run(ctx context.Context) error {
log.Info("Sidecar injector is shutting down")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err = i.server.Shutdown(shutdownCtx); err != nil {
if err := i.server.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("error while shutting down injector: %v; %v", err, <-errCh)
}
return <-errCh
case err = <-errCh:
case err := <-errCh:
return err
}
}

View File

@ -31,19 +31,18 @@ import (
)
func TestConfigCorrectValues(t *testing.T) {
i, err := NewInjector(nil, Config{
TLSCertFile: "a",
TLSKeyFile: "b",
SidecarImage: "c",
SidecarImagePullPolicy: "d",
Namespace: "e",
AllowedServiceAccountsPrefixNames: "ns*:sa,namespace:sa*",
}, nil, nil)
i, err := NewInjector(Options{
Config: Config{
SidecarImage: "c",
SidecarImagePullPolicy: "d",
Namespace: "e",
AllowedServiceAccountsPrefixNames: "ns*:sa,namespace:sa*",
ControlPlaneTrustDomain: "trust.domain",
},
})
assert.NoError(t, err)
injector := i.(*injector)
assert.Equal(t, "a", injector.config.TLSCertFile)
assert.Equal(t, "b", injector.config.TLSKeyFile)
assert.Equal(t, "c", injector.config.SidecarImage)
assert.Equal(t, "d", injector.config.SidecarImagePullPolicy)
assert.Equal(t, "e", injector.config.Namespace)
@ -53,14 +52,14 @@ func TestConfigCorrectValues(t *testing.T) {
}
func TestNewInjectorBadAllowedPrefixedServiceAccountConfig(t *testing.T) {
_, err := NewInjector(nil, Config{
TLSCertFile: "a",
TLSKeyFile: "b",
SidecarImage: "c",
SidecarImagePullPolicy: "d",
Namespace: "e",
AllowedServiceAccountsPrefixNames: "ns*:sa,namespace:sa*sa",
}, nil, nil)
_, err := NewInjector(Options{
Config: Config{
SidecarImage: "c",
SidecarImagePullPolicy: "d",
Namespace: "e",
AllowedServiceAccountsPrefixNames: "ns*:sa,namespace:sa*sa",
},
})
assert.Error(t, err)
}

View File

@ -22,13 +22,10 @@ import (
admissionv1 "k8s.io/api/admission/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
scheme "github.com/dapr/dapr/pkg/client/clientset/versioned"
"github.com/dapr/dapr/pkg/credentials"
injectorConsts "github.com/dapr/dapr/pkg/injector/consts"
"github.com/dapr/dapr/pkg/injector/patcher"
securityConsts "github.com/dapr/dapr/pkg/security/consts"
)
const (
@ -53,17 +50,20 @@ func (i *injector) getPodPatchOperations(ctx context.Context, ar *admissionv1.Ad
sentryAddress := patcher.ServiceAddress(patcher.ServiceSentry, i.config.Namespace, i.config.KubeClusterDomain)
operatorAddress := patcher.ServiceAddress(patcher.ServiceAPI, i.config.Namespace, i.config.KubeClusterDomain)
// Get the TLS credentials
trustAnchors, certChain, certKey := GetTrustAnchorsAndCertChain(ctx, i.kubeClient, i.config.Namespace)
trustAnchors, err := i.currentTrustAnchors()
if err != nil {
return nil, err
}
daprdCert, daprdPrivateKey, err := i.signDaprdCertificate(ctx, pod.Namespace)
if err != nil {
return nil, err
}
// Create the sidecar configuration object from the pod
sidecar := patcher.NewSidecarConfig(pod)
sidecar.GetInjectedComponentContainers = i.getInjectedComponentContainers
sidecar.Mode = injectorConsts.ModeKubernetes
sidecar.Namespace = ar.Request.Namespace
sidecar.TrustAnchors = trustAnchors
sidecar.CertChain = certChain
sidecar.CertKey = certKey
sidecar.MTLSEnabled = mTLSEnabled(i.daprClient)
sidecar.Identity = ar.Request.Namespace + ":" + pod.Spec.ServiceAccountName
sidecar.IgnoreEntrypointTolerations = i.config.GetIgnoreEntrypointTolerations()
@ -73,6 +73,11 @@ func (i *injector) getPodPatchOperations(ctx context.Context, ar *admissionv1.Ad
sidecar.RunAsNonRoot = i.config.GetRunAsNonRoot()
sidecar.ReadOnlyRootFilesystem = i.config.GetReadOnlyRootFilesystem()
sidecar.SidecarDropALLCapabilities = i.config.GetDropCapabilities()
sidecar.ControlPlaneNamespace = i.controlPlaneNamespace
sidecar.ControlPlaneTrustDomain = i.controlPlaneTrustDomain
sidecar.CurrentTrustAnchors = trustAnchors
sidecar.CertChain = string(daprdCert)
sidecar.CertKey = string(daprdPrivateKey)
// Set the placement address unless it's skipped
// Even if the placement is skipped, however,the placement address will still be included if explicitly set in the annotations
@ -113,18 +118,3 @@ func mTLSEnabled(daprClient scheme.Interface) bool {
log.Infof("Dapr system configuration '%s' does not exist; using default value %t for mTLSEnabled", defaultConfig, defaultMtlsEnabled)
return defaultMtlsEnabled
}
// GetTrustAnchorsAndCertChain returns the trust anchor and certs.
func GetTrustAnchorsAndCertChain(ctx context.Context, kubeClient kubernetes.Interface, namespace string) (string, string, string) {
secret, err := kubeClient.CoreV1().
Secrets(namespace).
Get(ctx, securityConsts.TrustBundleK8sSecretName, metav1.GetOptions{})
if err != nil {
return "", "", ""
}
rootCert := secret.Data[credentials.RootCertFilename]
certChain := secret.Data[credentials.IssuerCertFilename]
certKey := secret.Data[credentials.IssuerKeyFilename]
return string(rootCert), string(certChain), string(certKey)
}

View File

@ -259,7 +259,7 @@ func (d *directMessaging) invokeHTTPEndpoint(ctx context.Context, appID, appName
}
func (d *directMessaging) invokeRemote(ctx context.Context, appID, appNamespace, appAddress string, req *invokev1.InvokeMethodRequest) (*invokev1.InvokeMethodResponse, func(destroy bool), error) {
conn, teardown, err := d.connectionCreatorFn(context.TODO(), appAddress, appID, appNamespace)
conn, teardown, err := d.connectionCreatorFn(ctx, appAddress, appID, appNamespace)
if err != nil {
if teardown == nil {
teardown = nopTeardown

View File

@ -90,7 +90,7 @@ func (a *apiServer) Run(ctx context.Context, sec security.Handler) error {
log.Infof("starting gRPC server on port %d", serverPort)
s := grpc.NewServer(sec.GRPCServerOption())
s := grpc.NewServer(sec.GRPCServerOptionMTLS())
operatorv1pb.RegisterOperatorServer(s, a)
lis, err := net.Listen("tcp", fmt.Sprintf(":%v", serverPort))

View File

@ -2,20 +2,16 @@ package client
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"time"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcRetry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
daprCredentials "github.com/dapr/dapr/pkg/credentials"
diag "github.com/dapr/dapr/pkg/diagnostics"
operatorv1pb "github.com/dapr/dapr/pkg/proto/operator/v1"
"github.com/dapr/dapr/pkg/security"
)
const (
@ -24,13 +20,7 @@ const (
// GetOperatorClient returns a new k8s operator client and the underlying connection.
// If a cert chain is given, a TLS connection will be established.
func GetOperatorClient(ctx context.Context,
address, serverName string, certChain *daprCredentials.CertChain,
) (operatorv1pb.OperatorClient, *grpc.ClientConn, error) {
if certChain == nil {
return nil, nil, errors.New("certificate chain cannot be nil")
}
func GetOperatorClient(ctx context.Context, address string, sec security.Handler) (operatorv1pb.OperatorClient, *grpc.ClientConn, error) {
unaryClientInterceptor := grpcRetry.UnaryClientInterceptor()
if diag.DefaultGRPCMonitoring.IsEnabled() {
@ -40,24 +30,18 @@ func GetOperatorClient(ctx context.Context,
)
}
opts := []grpc.DialOption{grpc.WithUnaryInterceptor(unaryClientInterceptor)}
cp := x509.NewCertPool()
ok := cp.AppendCertsFromPEM(certChain.RootCA)
if !ok {
return nil, nil, errors.New("failed to append PEM root cert to x509 CertPool")
}
config, err := daprCredentials.TLSConfigFromCertAndKey(certChain.Cert, certChain.Key, serverName, cp)
config.MinVersion = tls.VersionTLS12
operatorID, err := spiffeid.FromSegments(sec.ControlPlaneTrustDomain(), "ns", sec.ControlPlaneNamespace(), "dapr-operator")
if err != nil {
return nil, nil, fmt.Errorf("failed to create tls config from cert and key: %w", err)
return nil, nil, err
}
// block for connection
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(config)), grpc.WithBlock(), grpc.WithReturnConnectionError())
ctx, cancelFunc := context.WithTimeout(ctx, dialTimeout)
defer cancelFunc()
opts := []grpc.DialOption{
grpc.WithUnaryInterceptor(unaryClientInterceptor),
sec.GRPCDialOptionMTLS(operatorID), grpc.WithReturnConnectionError(),
}
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
conn, err := grpc.DialContext(ctx, address, opts...)
if err != nil {
return nil, nil, err

View File

@ -230,7 +230,7 @@ func (o *operator) Run(ctx context.Context) error {
caBundleCh := make(chan []byte)
runner := concurrency.NewRunnerManager(
o.secProvider.Start,
o.secProvider.Run,
func(ctx context.Context) error {
// Wait for webhook certificates to be ready before starting the manager.
_, rErr := o.secProvider.Handler(ctx)

View File

@ -148,7 +148,7 @@ func (p *Service) Run(ctx context.Context, port string, sec security.Handler) er
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}
grpcServer := grpc.NewServer(sec.GRPCServerOption())
grpcServer := grpc.NewServer(sec.GRPCServerOptionMTLS())
placementv1pb.RegisterPlacementServer(grpcServer, p)

View File

@ -27,7 +27,6 @@ import (
env "github.com/dapr/dapr/pkg/config/env"
configmodes "github.com/dapr/dapr/pkg/config/modes"
"github.com/dapr/dapr/pkg/config/protocol"
"github.com/dapr/dapr/pkg/credentials"
diag "github.com/dapr/dapr/pkg/diagnostics"
"github.com/dapr/dapr/pkg/metrics"
"github.com/dapr/dapr/pkg/modes"
@ -36,7 +35,7 @@ import (
resiliencyConfig "github.com/dapr/dapr/pkg/resiliency"
rterrors "github.com/dapr/dapr/pkg/runtime/errors"
"github.com/dapr/dapr/pkg/runtime/registry"
"github.com/dapr/dapr/pkg/runtime/security"
"github.com/dapr/dapr/pkg/security"
"github.com/dapr/dapr/pkg/validation"
"github.com/dapr/dapr/utils"
"github.com/dapr/kit/ptr"
@ -105,6 +104,7 @@ type Config struct {
AppChannelAddress string
Metrics *metrics.Options
Registry *registry.Options
Security security.Handler
}
type internalConfig struct {
@ -131,7 +131,6 @@ type internalConfig struct {
enableAPILogging *bool
disableBuiltinK8sSecretStore bool
config []string
certChain *credentials.CertChain
registry *registry.Registry
metricsExporter metrics.Exporter
}
@ -165,18 +164,11 @@ func FromConfig(ctx context.Context, cfg *Config) (*DaprRuntime, error) {
return nil, err
}
if intc.mTLSEnabled || intc.mode == modes.KubernetesMode {
intc.certChain, err = security.GetCertChain()
if err != nil {
return nil, err
}
}
// Config and resiliency need the operator client
var operatorClient operatorV1.OperatorClient
if intc.mode == modes.KubernetesMode {
log.Info("Initializing the operator client")
client, conn, clientErr := client.GetOperatorClient(ctx, intc.kubernetes.ControlPlaneAddress, security.TLSServerName, intc.certChain)
client, conn, clientErr := client.GetOperatorClient(ctx, cfg.ControlPlaneAddress, cfg.Security)
if clientErr != nil {
return nil, clientErr
}
@ -259,7 +251,7 @@ func FromConfig(ctx context.Context, cfg *Config) (*DaprRuntime, error) {
intc.enableAPILogging = ptr.Of(globalConfig.GetAPILoggingSpec().Enabled)
}
return newDaprRuntime(ctx, intc, globalConfig, accessControlList, resiliencyProvider)
return newDaprRuntime(ctx, cfg.Security, intc, globalConfig, accessControlList, resiliencyProvider)
}
func (c *Config) toInternal() (*internalConfig, error) {

View File

@ -15,6 +15,7 @@ package binding
import (
"context"
"crypto/x509"
"io"
"net/http"
"testing"
@ -23,6 +24,7 @@ import (
"github.com/phayes/freeport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
v1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -39,6 +41,7 @@ import (
"github.com/dapr/dapr/pkg/runtime/meta"
rtmock "github.com/dapr/dapr/pkg/runtime/mock"
"github.com/dapr/dapr/pkg/runtime/registry"
"github.com/dapr/dapr/pkg/security"
daprt "github.com/dapr/dapr/pkg/testing"
testinggrpc "github.com/dapr/dapr/pkg/testing/grpc"
"github.com/dapr/kit/logger"
@ -174,6 +177,21 @@ func TestStartReadingFromBindings(t *testing.T) {
}
func TestGetSubscribedBindingsGRPC(t *testing.T) {
secP, err := security.New(context.Background(), security.Options{
TrustAnchors: []byte("test"),
AppID: "test",
ControlPlaneTrustDomain: "test.example.com",
ControlPlaneNamespace: "default",
MTLSEnabled: false,
OverrideCertRequestSource: func(context.Context, []byte) ([]*x509.Certificate, error) {
return []*x509.Certificate{nil}, nil
},
})
require.NoError(t, err)
go secP.Run(context.Background())
sec, err := secP.Handler(context.Background())
require.NoError(t, err)
testCases := []struct {
name string
expectedResponse []string
@ -199,7 +217,7 @@ func TestGetSubscribedBindingsGRPC(t *testing.T) {
Resiliency: resiliency.New(log),
ComponentStore: compstore.New(),
Meta: meta.New(meta.Options{}),
GRPC: manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port}),
GRPC: manager.NewManager(sec, modes.StandaloneMode, &manager.AppChannelConfig{Port: port}),
})
// create mock application server first
grpcServer := testinggrpc.StartTestAppCallbackGRPCServer(t, port, &channelt.MockServer{

View File

@ -951,7 +951,7 @@ func TestBulkSubscribeGRPC(t *testing.T) {
defer grpcServer.Stop()
}
grpc := manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
grpc := manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
// create a new AppChannel and gRPC client for every test
ps.channels = channels.New(channels.Options{
@ -1112,7 +1112,7 @@ func TestBulkSubscribeGRPC(t *testing.T) {
}
// create a new AppChannel and gRPC client for every test
grpc := manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
grpc := manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
ps.channels = channels.New(channels.Options{
ComponentStore: compstore.New(),
Registry: reg,
@ -1219,7 +1219,7 @@ func TestBulkSubscribeGRPC(t *testing.T) {
defer grpcServer.Stop()
}
grpc := manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
grpc := manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
ps.channels = channels.New(channels.Options{
ComponentStore: compstore.New(),
Registry: reg,
@ -1329,7 +1329,7 @@ func TestBulkSubscribeGRPC(t *testing.T) {
defer grpcServer.Stop()
}
grpc := manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
grpc := manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
ps.channels = channels.New(channels.Options{
ComponentStore: compstore.New(),
Registry: reg,
@ -1434,7 +1434,7 @@ func TestBulkSubscribeGRPC(t *testing.T) {
defer grpcServer.Stop()
}
grpc := manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
grpc := manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
ps.channels = channels.New(channels.Options{
ComponentStore: compstore.New(),
Registry: reg,
@ -1525,7 +1525,7 @@ func TestBulkSubscribeGRPC(t *testing.T) {
defer grpcServer.Stop()
}
grpc := manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
grpc := manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
ps.channels = channels.New(channels.Options{
ComponentStore: compstore.New(),
Registry: reg,

View File

@ -207,7 +207,7 @@ func TestErrorPublishedNonCloudEventGRPC(t *testing.T) {
Mode: modes.StandaloneMode,
Namespace: "ns1",
ID: TestRuntimeConfigID,
GRPC: manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{}),
GRPC: manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{}),
})
ps.compStore.SetTopicRoutes(map[string]compstore.TopicRoutes{
TestPubsubName: map[string]compstore.TopicRouteElem{
@ -738,7 +738,7 @@ func TestOnNewPublishedMessageGRPC(t *testing.T) {
Mode: modes.StandaloneMode,
Namespace: "ns1",
ID: TestRuntimeConfigID,
GRPC: manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port}),
GRPC: manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{Port: port}),
})
ps.compStore.SetTopicRoutes(map[string]compstore.TopicRoutes{
TestPubsubName: map[string]compstore.TopicRouteElem{
@ -768,7 +768,7 @@ func TestOnNewPublishedMessageGRPC(t *testing.T) {
defer grpcServer.Stop()
}
grpc := manager.NewManager(modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
grpc := manager.NewManager(nil, modes.StandaloneMode, &manager.AppChannelConfig{Port: port})
ps.channels = channels.New(channels.Options{
ComponentStore: compstore.New(),
Registry: reg,

View File

@ -69,8 +69,8 @@ import (
"github.com/dapr/dapr/pkg/runtime/meta"
"github.com/dapr/dapr/pkg/runtime/processor"
"github.com/dapr/dapr/pkg/runtime/registry"
"github.com/dapr/dapr/pkg/runtime/security"
"github.com/dapr/dapr/pkg/runtime/wfengine"
"github.com/dapr/dapr/pkg/security"
securityConsts "github.com/dapr/dapr/pkg/security/consts"
"github.com/dapr/dapr/utils"
"github.com/dapr/kit/logger"
@ -119,7 +119,6 @@ type DaprRuntime struct {
nameResolver nr.Resolver
hostAddress string
actorStateStoreLock sync.RWMutex
authenticator security.Authenticator
namespace string
podName string
daprHTTPAPI http.API
@ -133,6 +132,7 @@ type DaprRuntime struct {
compStore *compstore.ComponentStore
processor *processor.Processor
meta *meta.Meta
sec security.Handler
runnerCloser *concurrency.RunnerCloserManager
pendingHTTPEndpoints chan httpEndpointV1alpha1.HTTPEndpoint
@ -156,6 +156,7 @@ type componentPreprocessRes struct {
// newDaprRuntime returns a new runtime with the given runtime config and global config.
func newDaprRuntime(ctx context.Context,
sec security.Handler,
runtimeConfig *internalConfig,
globalConfig *config.Configuration,
accessControlList *config.AccessControlList,
@ -170,12 +171,12 @@ func newDaprRuntime(ctx context.Context,
Mode: runtimeConfig.mode,
})
operatorClient, err := getOperatorClient(ctx, runtimeConfig)
operatorClient, err := getOperatorClient(ctx, sec, runtimeConfig)
if err != nil {
return nil, err
}
grpc := createGRPCManager(runtimeConfig, globalConfig)
grpc := createGRPCManager(sec, runtimeConfig, globalConfig)
wfe := wfengine.NewWorkflowEngine(wfengine.NewWorkflowConfig(runtimeConfig.id))
wfe.ConfigureGrpcExecutor()
@ -207,6 +208,7 @@ func newDaprRuntime(ctx context.Context,
meta: meta,
operatorClient: operatorClient,
channels: channels,
sec: sec,
processor: processor.New(processor.Options{
ID: runtimeConfig.id,
Namespace: getNamespace(),
@ -317,13 +319,13 @@ func getPodName() string {
return os.Getenv("POD_NAME")
}
func getOperatorClient(ctx context.Context, cfg *internalConfig) (operatorv1pb.OperatorClient, error) {
func getOperatorClient(ctx context.Context, sec security.Handler, cfg *internalConfig) (operatorv1pb.OperatorClient, error) {
// Get the operator client only if we're running in Kubernetes and if we need it
if cfg.mode != modes.KubernetesMode {
return nil, nil
}
client, _, err := client.GetOperatorClient(ctx, cfg.kubernetes.ControlPlaneAddress, security.TLSServerName, cfg.certChain)
client, _, err := client.GetOperatorClient(ctx, cfg.kubernetes.ControlPlaneAddress, sec)
if err != nil {
return nil, fmt.Errorf("error creating operator client: %w", err)
}
@ -403,13 +405,9 @@ func (a *DaprRuntime) setupTracing(ctx context.Context, hostAddress string, tpSt
func (a *DaprRuntime) initRuntime(ctx context.Context) error {
a.namespace = getNamespace()
err := a.establishSecurity(a.runtimeConfig.sentryServiceAddress)
if err != nil {
return err
}
a.podName = getPodName()
var err error
if a.hostAddress, err = utils.GetHostAddress(); err != nil {
return fmt.Errorf("failed to determine host address: %w", err)
}
@ -573,7 +571,7 @@ func (a *DaprRuntime) appHealthReadyInit(ctx context.Context) error {
var err error
// Load app configuration (for actors) and init actors
a.loadAppConfiguration()
a.loadAppConfiguration(ctx)
if len(a.runtimeConfig.placementAddresses) != 0 {
err = a.initActors(ctx)
@ -1022,7 +1020,7 @@ func (a *DaprRuntime) startHTTPServer(port int, publicPort *int, profilePort int
func (a *DaprRuntime) startGRPCInternalServer(api grpc.API, port int) error {
// Since GRPCInteralServer is encrypted & authenticated, it is safe to listen on *
serverConf := a.getNewServerConfig([]string{""}, port)
server := grpc.NewInternalServer(api, serverConf, a.globalConfig.GetTracingSpec(), a.globalConfig.GetMetricsSpec(), a.authenticator, a.proxy)
server := grpc.NewInternalServer(api, serverConf, a.globalConfig.GetTracingSpec(), a.globalConfig.GetMetricsSpec(), a.sec, a.proxy)
if err := server.StartNonBlocking(); err != nil {
return err
}
@ -1158,13 +1156,13 @@ func (a *DaprRuntime) initActors(ctx context.Context) error {
AppChannel: a.channels.AppChannel(),
GRPCConnectionFn: a.grpc.GetGRPCConnection,
Config: actorConfig,
CertChain: a.runtimeConfig.certChain,
TracingSpec: a.globalConfig.GetTracingSpec(),
Resiliency: a.resiliency,
StateStoreName: actorStateStoreName,
CompStore: a.compStore,
// TODO: @joshvanl Remove in Dapr 1.12 when ActorStateTTL is finalized.
StateTTLEnabled: a.globalConfig.IsFeatureEnabled(config.ActorStateTTL),
Security: a.sec,
})
err = act.Init(ctx)
if err == nil {
@ -1613,12 +1611,12 @@ func (a *DaprRuntime) blockUntilAppIsReady(ctx context.Context) error {
return nil
}
func (a *DaprRuntime) loadAppConfiguration() {
func (a *DaprRuntime) loadAppConfiguration(ctx context.Context) {
if a.channels.AppChannel() == nil {
return
}
appConfig, err := a.channels.AppChannel().GetAppConfig(a.runtimeConfig.id)
appConfig, err := a.channels.AppChannel().GetAppConfig(ctx, a.runtimeConfig.id)
if err != nil {
return
}
@ -1711,34 +1709,11 @@ func featureTypeToString(features interface{}) []string {
return featureStr
}
func (a *DaprRuntime) establishSecurity(sentryAddress string) error {
if !a.runtimeConfig.mTLSEnabled {
log.Info("mTLS is disabled. Skipping certificate request and tls validation")
return nil
}
if sentryAddress == "" {
return errors.New("sentryAddress cannot be empty")
}
log.Info("mTLS enabled; Creating sidecar authenticator")
auth, err := security.GetSidecarAuthenticator(sentryAddress, a.runtimeConfig.certChain)
if err != nil {
return err
}
a.authenticator = auth
a.grpc.SetAuthenticator(auth)
log.Info("Authenticator created")
diag.DefaultMonitoring.MTLSInitCompleted()
return nil
}
func componentDependency(compCategory components.Category, name string) string {
return fmt.Sprintf("%s:%s", compCategory, name)
}
func createGRPCManager(runtimeConfig *internalConfig, globalConfig *config.Configuration) *manager.Manager {
func createGRPCManager(sec security.Handler, runtimeConfig *internalConfig, globalConfig *config.Configuration) *manager.Manager {
grpcAppChannelConfig := &manager.AppChannelConfig{}
if globalConfig != nil {
grpcAppChannelConfig.TracingSpec = globalConfig.GetTracingSpec()
@ -1753,7 +1728,7 @@ func createGRPCManager(runtimeConfig *internalConfig, globalConfig *config.Confi
grpcAppChannelConfig.BaseAddress = runtimeConfig.appConnectionConfig.ChannelAddress
}
m := manager.NewManager(runtimeConfig.mode, grpcAppChannelConfig)
m := manager.NewManager(sec, runtimeConfig.mode, grpcAppChannelConfig)
m.StartCollector()
return m
}

View File

@ -17,6 +17,7 @@ package runtime
import (
"context"
"crypto/rand"
"crypto/x509"
"encoding/hex"
"encoding/json"
"errors"
@ -71,6 +72,7 @@ import (
secretstoresLoader "github.com/dapr/dapr/pkg/components/secretstores"
"github.com/dapr/dapr/pkg/config/protocol"
"github.com/dapr/dapr/pkg/metrics"
"github.com/dapr/dapr/pkg/security"
stateLoader "github.com/dapr/dapr/pkg/components/state"
"github.com/dapr/dapr/pkg/config"
@ -87,7 +89,6 @@ import (
"github.com/dapr/dapr/pkg/runtime/processor"
runtimePubsub "github.com/dapr/dapr/pkg/runtime/pubsub"
"github.com/dapr/dapr/pkg/runtime/registry"
"github.com/dapr/dapr/pkg/runtime/security"
securityConsts "github.com/dapr/dapr/pkg/security/consts"
daprt "github.com/dapr/dapr/pkg/testing"
"github.com/dapr/kit/logger"
@ -102,21 +103,9 @@ const (
maxGRPCServerUptime = 200 * time.Millisecond
)
var testCertRoot = `-----BEGIN CERTIFICATE-----
MIIBjjCCATOgAwIBAgIQdZeGNuAHZhXSmb37Pnx2QzAKBggqhkjOPQQDAjAYMRYw
FAYDVQQDEw1jbHVzdGVyLmxvY2FsMB4XDTIwMDIwMTAwMzUzNFoXDTMwMDEyOTAw
MzUzNFowGDEWMBQGA1UEAxMNY2x1c3Rlci5sb2NhbDBZMBMGByqGSM49AgEGCCqG
SM49AwEHA0IABAeMFRst4JhcFpebfgEs1MvJdD7h5QkCbLwChRHVEUoaDqd1aYjm
bX5SuNBXz5TBEhHfTV3Objh6LQ2N+CBoCeOjXzBdMA4GA1UdDwEB/wQEAwIBBjAS
BgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBRBWthv5ZQ3vALl2zXWwAXSmZ+m
qTAYBgNVHREEETAPgg1jbHVzdGVyLmxvY2FsMAoGCCqGSM49BAMCA0kAMEYCIQDN
rQNOck4ENOhmLROE/wqH0MKGjE6P8yzesgnp9fQI3AIhAJaVPrZloxl1dWCgmNWo
Iklq0JnMgJU7nS+VpVvlgBN8
-----END CERTIFICATE-----`
func TestNewRuntime(t *testing.T) {
// act
r, err := newDaprRuntime(context.Background(), &internalConfig{
r, err := newDaprRuntime(context.Background(), nil, &internalConfig{
metricsExporter: metrics.NewExporter(log, metrics.DefaultMetricNamespace),
registry: registry.New(registry.NewOptions()),
}, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
@ -127,7 +116,7 @@ func TestNewRuntime(t *testing.T) {
}
func TestProcessComponentsAndDependents(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
@ -150,7 +139,7 @@ func TestProcessComponentsAndDependents(t *testing.T) {
}
func TestDoProcessComponent(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
@ -526,7 +515,7 @@ func (cs *mockOperatorHTTPEndpointUpdateClientStream) Recv() (*operatorv1pb.HTTP
}
func TestComponentsUpdate(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.KubernetesMode)
rt, err := NewTestDaprRuntime(t, modes.KubernetesMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
@ -641,7 +630,7 @@ func TestComponentsUpdate(t *testing.T) {
// Test that flushOutstandingComponents waits for components.
func TestFlushOutstandingComponent(t *testing.T) {
t.Run("We can call flushOustandingComponents more than once", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
wasCalled := false
@ -692,7 +681,7 @@ func TestFlushOutstandingComponent(t *testing.T) {
assert.True(t, wasCalled)
})
t.Run("flushOutstandingComponents blocks for components with outstanding dependanices", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
wasCalled := false
@ -793,7 +782,7 @@ func TestFlushOutstandingComponent(t *testing.T) {
func TestInitSecretStores(t *testing.T) {
t.Run("init with store", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
m := NewMockKubernetesStore()
@ -817,7 +806,7 @@ func TestInitSecretStores(t *testing.T) {
})
t.Run("secret store is registered", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
m := NewMockKubernetesStore()
@ -844,7 +833,7 @@ func TestInitSecretStores(t *testing.T) {
})
t.Run("get secret store", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
m := NewMockKubernetesStore()
@ -900,7 +889,7 @@ func TestInitNameResolution(t *testing.T) {
t.Run("error on unknown resolver", func(t *testing.T) {
// given
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
// target resolver
@ -920,7 +909,7 @@ func TestInitNameResolution(t *testing.T) {
t.Run("test init nameresolution", func(t *testing.T) {
// given
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
// target resolver
@ -940,7 +929,7 @@ func TestInitNameResolution(t *testing.T) {
t.Run("test init nameresolution default in StandaloneMode", func(t *testing.T) {
// given
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
// target resolver
@ -958,7 +947,7 @@ func TestInitNameResolution(t *testing.T) {
t.Run("test init nameresolution nil in StandaloneMode", func(t *testing.T) {
// given
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
// target resolver
@ -976,7 +965,7 @@ func TestInitNameResolution(t *testing.T) {
t.Run("test init nameresolution default in KubernetesMode", func(t *testing.T) {
// given
rt, err := NewTestDaprRuntime(modes.KubernetesMode)
rt, err := NewTestDaprRuntime(t, modes.KubernetesMode)
require.NoError(t, err)
// target resolver
@ -994,7 +983,7 @@ func TestInitNameResolution(t *testing.T) {
t.Run("test init nameresolution nil in KubernetesMode", func(t *testing.T) {
// given
rt, err := NewTestDaprRuntime(modes.KubernetesMode)
rt, err := NewTestDaprRuntime(t, modes.KubernetesMode)
require.NoError(t, err)
// target resolver
@ -1087,7 +1076,7 @@ func TestSetupTracing(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.globalConfig.Spec.TracingSpec = &tc.tracingConfig
@ -1149,7 +1138,7 @@ func TestMetadataUUID(t *testing.T) {
},
},
})
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
mockPubSub := new(daprt.MockPubSub)
@ -1209,7 +1198,7 @@ func TestMetadataPodName(t *testing.T) {
},
},
})
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
rt, _ := NewTestDaprRuntime(t, modes.KubernetesMode)
defer stopRuntime(t, rt)
mockPubSub := new(daprt.MockPubSub)
@ -1255,7 +1244,7 @@ func TestMetadataNamespace(t *testing.T) {
},
},
})
rt, _ := NewTestDaprRuntimeWithID(modes.KubernetesMode, "app1")
rt, _ := NewTestDaprRuntimeWithID(t, modes.KubernetesMode, "app1")
defer stopRuntime(t, rt)
mockPubSub := new(daprt.MockPubSub)
@ -1304,7 +1293,7 @@ func TestMetadataClientID(t *testing.T) {
},
})
rt, err := NewTestDaprRuntimeWithID(modes.KubernetesMode, "myApp")
rt, err := NewTestDaprRuntimeWithID(t, modes.KubernetesMode, "myApp")
require.NoError(t, err)
rt.runtimeConfig.id = daprt.TestRuntimeConfigID
@ -1350,7 +1339,7 @@ func TestMetadataClientID(t *testing.T) {
},
})
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
rt.runtimeConfig.id = daprt.TestRuntimeConfigID
@ -1391,7 +1380,7 @@ func TestMetadataClientID(t *testing.T) {
func TestOnComponentUpdated(t *testing.T) {
t.Run("component spec changed, component is updated", func(t *testing.T) {
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
rt, _ := NewTestDaprRuntime(t, modes.KubernetesMode)
rt.compStore.AddComponent(componentsV1alpha1.Component{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
@ -1440,7 +1429,7 @@ func TestOnComponentUpdated(t *testing.T) {
})
t.Run("component spec unchanged, component is skipped", func(t *testing.T) {
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
rt, _ := NewTestDaprRuntime(t, modes.KubernetesMode)
rt.compStore.AddComponent(componentsV1alpha1.Component{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
@ -1492,7 +1481,7 @@ func TestOnComponentUpdated(t *testing.T) {
func TestPopulateSecretsConfiguration(t *testing.T) {
t.Run("secret store configuration is populated", func(t *testing.T) {
// setup
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.globalConfig.Spec.Secrets = &config.SecretsSpec{
@ -1541,7 +1530,7 @@ func TestProcessResourceSecrets(t *testing.T) {
})
mockBinding.Auth.SecretStore = secretstoresLoader.BuiltinKubernetesSecretStore
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
m := NewMockKubernetesStore()
@ -1579,7 +1568,7 @@ func TestProcessResourceSecrets(t *testing.T) {
})
mockBinding.Auth.SecretStore = "mock"
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
rt, _ := NewTestDaprRuntime(t, modes.KubernetesMode)
defer stopRuntime(t, rt)
rt.runtimeConfig.registry.SecretStores().RegisterComponent(
@ -1616,7 +1605,7 @@ func TestProcessResourceSecrets(t *testing.T) {
EnvRef: "MY_ENV_VAR",
})
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
@ -1642,7 +1631,7 @@ func TestProcessResourceSecrets(t *testing.T) {
},
)
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
@ -1657,7 +1646,7 @@ func TestProcessResourceSecrets(t *testing.T) {
// Test InitSecretStore if secretstore.* refers to Kubernetes secret store.
func TestInitSecretStoresInKubernetesMode(t *testing.T) {
t.Run("built-in secret store is added", func(t *testing.T) {
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
rt, _ := NewTestDaprRuntime(t, modes.KubernetesMode)
m := NewMockKubernetesStore()
rt.runtimeConfig.registry.SecretStores().RegisterComponent(
@ -1671,7 +1660,7 @@ func TestInitSecretStoresInKubernetesMode(t *testing.T) {
})
t.Run("disable built-in secret store flag", func(t *testing.T) {
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
rt, _ := NewTestDaprRuntime(t, modes.KubernetesMode)
defer stopRuntime(t, rt)
rt.runtimeConfig.disableBuiltinK8sSecretStore = true
@ -1691,7 +1680,7 @@ func TestInitSecretStoresInKubernetesMode(t *testing.T) {
})
t.Run("built-in secret store bypasses authorizers", func(t *testing.T) {
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
rt, _ := NewTestDaprRuntime(t, modes.KubernetesMode)
rt.componentAuthorizers = []ComponentAuthorizer{
func(component componentsV1alpha1.Component) bool {
return false
@ -1727,14 +1716,14 @@ func assertBuiltInSecretStore(t *testing.T, rt *DaprRuntime) {
assert.NoError(t, rt.runnerCloser.Close())
}
func NewTestDaprRuntime(mode modes.DaprMode) (*DaprRuntime, error) {
return NewTestDaprRuntimeWithProtocol(mode, string(protocol.HTTPProtocol), 1024)
func NewTestDaprRuntime(t *testing.T, mode modes.DaprMode) (*DaprRuntime, error) {
return NewTestDaprRuntimeWithProtocol(t, mode, string(protocol.HTTPProtocol), 1024)
}
func NewTestDaprRuntimeWithID(mode modes.DaprMode, id string) (*DaprRuntime, error) {
testRuntimeConfig := NewTestDaprRuntimeConfig(modes.StandaloneMode, string(protocol.HTTPProtocol), 1024)
func NewTestDaprRuntimeWithID(t *testing.T, mode modes.DaprMode, id string) (*DaprRuntime, error) {
testRuntimeConfig := NewTestDaprRuntimeConfig(t, modes.StandaloneMode, string(protocol.HTTPProtocol), 1024)
testRuntimeConfig.id = id
rt, err := newDaprRuntime(context.Background(), testRuntimeConfig, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
rt, err := newDaprRuntime(context.Background(), testSecurity(t), testRuntimeConfig, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
if err != nil {
return nil, err
}
@ -1743,9 +1732,9 @@ func NewTestDaprRuntimeWithID(mode modes.DaprMode, id string) (*DaprRuntime, err
return rt, nil
}
func NewTestDaprRuntimeWithProtocol(mode modes.DaprMode, protocol string, appPort int) (*DaprRuntime, error) {
testRuntimeConfig := NewTestDaprRuntimeConfig(modes.StandaloneMode, protocol, appPort)
rt, err := newDaprRuntime(context.Background(), testRuntimeConfig, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
func NewTestDaprRuntimeWithProtocol(t *testing.T, mode modes.DaprMode, protocol string, appPort int) (*DaprRuntime, error) {
testRuntimeConfig := NewTestDaprRuntimeConfig(t, modes.StandaloneMode, protocol, appPort)
rt, err := newDaprRuntime(context.Background(), testSecurity(t), testRuntimeConfig, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
if err != nil {
return nil, err
}
@ -1754,7 +1743,7 @@ func NewTestDaprRuntimeWithProtocol(mode modes.DaprMode, protocol string, appPor
return rt, nil
}
func NewTestDaprRuntimeConfig(mode modes.DaprMode, appProtocol string, appPort int) *internalConfig {
func NewTestDaprRuntimeConfig(t *testing.T, mode modes.DaprMode, appProtocol string, appPort int) *internalConfig {
return &internalConfig{
id: daprt.TestRuntimeConfigID,
placementAddresses: []string{"10.10.10.12"},
@ -1798,51 +1787,11 @@ func NewTestDaprRuntimeConfig(mode modes.DaprMode, appProtocol string, appPort i
}
func TestGracefulShutdown(t *testing.T) {
r, err := NewTestDaprRuntime(modes.StandaloneMode)
r, err := NewTestDaprRuntime(t, modes.StandaloneMode)
assert.NoError(t, err)
assert.Equal(t, time.Second, r.runtimeConfig.gracefulShutdownDuration)
}
func TestMTLS(t *testing.T) {
t.Run("with mTLS enabled", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.runtimeConfig.mTLSEnabled = true
rt.runtimeConfig.sentryServiceAddress = "1.1.1.1"
t.Setenv(securityConsts.TrustAnchorsEnvVar, testCertRoot)
t.Setenv(securityConsts.CertChainEnvVar, "a")
t.Setenv(securityConsts.CertKeyEnvVar, "b")
certChain, err := security.GetCertChain()
assert.NoError(t, err)
rt.runtimeConfig.certChain = certChain
err = rt.establishSecurity(rt.runtimeConfig.sentryServiceAddress)
assert.NoError(t, err)
assert.NotNil(t, rt.authenticator)
})
t.Run("with mTLS disabled", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
err = rt.establishSecurity(rt.runtimeConfig.sentryServiceAddress)
assert.NoError(t, err)
assert.Nil(t, rt.authenticator)
})
t.Run("mTLS disabled, operator fails without TLS certs", func(t *testing.T) {
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
defer stopRuntime(t, rt)
_, err := getOperatorClient(context.Background(), rt.runtimeConfig)
assert.Error(t, err)
})
}
func TestNamespace(t *testing.T) {
t.Run("empty namespace", func(t *testing.T) {
assert.Empty(t, getNamespace())
@ -1869,7 +1818,7 @@ func TestAuthorizedComponents(t *testing.T) {
testCompName := "fakeComponent"
t.Run("standalone mode, no namespce", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
component := componentsV1alpha1.Component{}
@ -1883,7 +1832,7 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("namespace mismatch", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.namespace = "a"
@ -1899,7 +1848,7 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("namespace match", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.namespace = "a"
@ -1915,7 +1864,7 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("in scope, namespace match", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.namespace = "a"
@ -1932,7 +1881,7 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("not in scope, namespace match", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.namespace = "a"
@ -1949,7 +1898,7 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("in scope, namespace mismatch", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.namespace = "a"
@ -1966,7 +1915,7 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("not in scope, namespace mismatch", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.namespace = "a"
@ -1983,7 +1932,7 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("no authorizers", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.componentAuthorizers = []ComponentAuthorizer{}
@ -2002,7 +1951,7 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("only deny all", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
rt.componentAuthorizers = []ComponentAuthorizer{
@ -2021,8 +1970,8 @@ func TestAuthorizedComponents(t *testing.T) {
})
t.Run("additional authorizer denies all", func(t *testing.T) {
cfg := NewTestDaprRuntimeConfig(modes.StandaloneMode, string(protocol.HTTPSProtocol), 1024)
rt, err := newDaprRuntime(context.Background(), cfg, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
cfg := NewTestDaprRuntimeConfig(t, modes.StandaloneMode, string(protocol.HTTPSProtocol), 1024)
rt, err := newDaprRuntime(context.Background(), nil, cfg, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
require.NoError(t, err)
rt.componentAuthorizers = append(rt.componentAuthorizers, func(component componentsV1alpha1.Component) bool {
return false
@ -2040,7 +1989,7 @@ func TestAuthorizedComponents(t *testing.T) {
}
func TestAuthorizedHTTPEndpoints(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
endpoint := createTestEndpoint("testEndpoint", "http://api.test.com")
@ -2133,7 +2082,7 @@ func TestAuthorizedHTTPEndpoints(t *testing.T) {
func TestInitActors(t *testing.T) {
t.Run("missing namespace on kubernetes", func(t *testing.T) {
r, err := NewTestDaprRuntime(modes.KubernetesMode)
r, err := NewTestDaprRuntime(t, modes.KubernetesMode)
assert.NoError(t, err)
defer stopRuntime(t, r)
r.namespace = ""
@ -2144,7 +2093,7 @@ func TestInitActors(t *testing.T) {
})
t.Run("actors hosted = true", func(t *testing.T) {
r, err := NewTestDaprRuntime(modes.KubernetesMode)
r, err := NewTestDaprRuntime(t, modes.KubernetesMode)
require.NoError(t, err)
defer stopRuntime(t, r)
r.appConfig = config.ApplicationConfig{
@ -2156,7 +2105,7 @@ func TestInitActors(t *testing.T) {
})
t.Run("actors hosted = false", func(t *testing.T) {
r, err := NewTestDaprRuntime(modes.KubernetesMode)
r, err := NewTestDaprRuntime(t, modes.KubernetesMode)
require.NoError(t, err)
defer stopRuntime(t, r)
@ -2165,7 +2114,7 @@ func TestInitActors(t *testing.T) {
})
t.Run("placement enable = false", func(t *testing.T) {
r, err := newDaprRuntime(context.Background(), &internalConfig{
r, err := newDaprRuntime(context.Background(), testSecurity(t), &internalConfig{
metricsExporter: metrics.NewExporter(log, metrics.DefaultMetricNamespace),
registry: registry.New(registry.NewOptions()),
}, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
@ -2178,7 +2127,7 @@ func TestInitActors(t *testing.T) {
})
t.Run("the state stores can still be initialized normally", func(t *testing.T) {
r, err := newDaprRuntime(context.Background(), &internalConfig{
r, err := newDaprRuntime(context.Background(), testSecurity(t), &internalConfig{
metricsExporter: metrics.NewExporter(log, metrics.DefaultMetricNamespace),
registry: registry.New(registry.NewOptions()),
}, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
@ -2192,7 +2141,7 @@ func TestInitActors(t *testing.T) {
})
t.Run("the actor store can not be initialized normally", func(t *testing.T) {
r, err := newDaprRuntime(context.Background(), &internalConfig{
r, err := newDaprRuntime(context.Background(), testSecurity(t), &internalConfig{
metricsExporter: metrics.NewExporter(log, metrics.DefaultMetricNamespace),
registry: registry.New(registry.NewOptions()),
}, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
@ -2269,7 +2218,7 @@ func TestActorReentrancyConfig(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.Name, func(t *testing.T) {
r, err := NewTestDaprRuntime(modes.StandaloneMode)
r, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
mockAppChannel := new(channelt.MockAppChannel)
@ -2281,7 +2230,7 @@ func TestActorReentrancyConfig(t *testing.T) {
mockAppChannel.On("GetAppConfig").Return(&configResp, nil)
r.loadAppConfiguration()
r.loadAppConfiguration(context.Background())
assert.NotNil(t, r.appConfig)
@ -2318,7 +2267,7 @@ func (s *mockStateStore) Close() error {
}
func TestCloseWithErrors(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
testErr := errors.New("mock close error")
@ -2455,8 +2404,8 @@ func TestComponentsCallback(t *testing.T) {
c := make(chan struct{})
callbackInvoked := false
cfg := NewTestDaprRuntimeConfig(modes.StandaloneMode, "http", port)
rt, err := newDaprRuntime(context.Background(), cfg, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
cfg := NewTestDaprRuntimeConfig(t, modes.StandaloneMode, "http", port)
rt, err := newDaprRuntime(context.Background(), testSecurity(t), cfg, &config.Configuration{}, &config.AccessControlList{}, resiliency.New(logger.NewLogger("test")))
require.NoError(t, err)
rt.runtimeConfig.registry = registry.New(registry.NewOptions().WithComponentsCallback(func(components registry.ComponentRegistry) error {
close(c)
@ -2495,7 +2444,7 @@ func TestGRPCProxy(t *testing.T) {
defer teardown()
// setup proxy
rt, err := NewTestDaprRuntimeWithProtocol(modes.StandaloneMode, "grpc", serverPort)
rt, err := NewTestDaprRuntimeWithProtocol(t, modes.StandaloneMode, "grpc", serverPort)
require.NoError(t, err)
internalPort, _ := freeport.GetFreePort()
rt.runtimeConfig.internalGRPCPort = internalPort
@ -2571,7 +2520,7 @@ func TestGRPCProxy(t *testing.T) {
func TestShutdownWithWait(t *testing.T) {
t.Run("calling ShutdownWithWait should wait until runtime has stopped", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
closeSecretClose := make(chan struct{})
@ -2651,7 +2600,7 @@ spec:
})
t.Run("if secret times out after init, error should return from runtime and ShutdownWithWait should return", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
initSecretContextClosed := make(chan struct{})
@ -2729,7 +2678,7 @@ spec:
})
t.Run("if secret init fails then the runtime should not error when the error should be ignored. Should wait for shutdown signal", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
secretInited := make(chan struct{})
@ -2802,7 +2751,7 @@ spec:
}
})
t.Run("if secret init fails then the runtime should error when the error should NOT be ignored. Shouldn't wait for shutdown signal", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
m := NewMockKubernetesStoreWithInitCallback(func(ctx context.Context) error {
@ -2869,7 +2818,7 @@ spec:
})
t.Run("runtime should fatal if closing components does not happen in time", func(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
m := NewMockKubernetesStoreWithClose(func() error {
@ -2924,7 +2873,7 @@ spec:
}
func TestGetComponentsCapabilitiesMap(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
defer stopRuntime(t, rt)
@ -3071,7 +3020,7 @@ func matchDaprRequestMethod(method string) any {
}
func TestGracefulShutdownBindings(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
@ -3115,7 +3064,7 @@ func TestGracefulShutdownBindings(t *testing.T) {
}
func TestGracefulShutdownPubSub(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
mockPubSub := new(daprt.MockPubSub)
rt.runtimeConfig.registry.PubSubs().RegisterComponent(
@ -3184,7 +3133,7 @@ func TestGracefulShutdownPubSub(t *testing.T) {
}
func TestGracefulShutdownActors(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
rt.runtimeConfig.gracefulShutdownDuration = 5 * time.Second
@ -3286,7 +3235,7 @@ func initMockStateStoreForRuntime(rt *DaprRuntime, encryptKey string, e error) *
}
func TestTraceShutdown(t *testing.T) {
rt, err := NewTestDaprRuntime(modes.StandaloneMode)
rt, err := NewTestDaprRuntime(t, modes.StandaloneMode)
require.NoError(t, err)
rt.runtimeConfig.gracefulShutdownDuration = 5 * time.Second
rt.globalConfig.Spec.TracingSpec = &config.TracingSpec{
@ -3331,7 +3280,7 @@ func createTestEndpoint(name, baseURL string) httpEndpointV1alpha1.HTTPEndpoint
}
func TestHTTPEndpointsUpdate(t *testing.T) {
rt, _ := NewTestDaprRuntime(modes.KubernetesMode)
rt, _ := NewTestDaprRuntime(t, modes.KubernetesMode)
defer stopRuntime(t, rt)
mockOpCli := newMockOperatorClient()
@ -3488,8 +3437,6 @@ func TestIsEnvVarAllowed(t *testing.T) {
{name: "keys starting with DAPR_ are denied", key: "DAPR_TEST", want: false},
{name: "APP_API_TOKEN is denied", key: "APP_API_TOKEN", want: false},
{name: "keys with a space are denied", key: "FOO BAR", want: false},
{name: "case insensitive app_api_token", key: "app_api_token", want: false},
{name: "case insensitive dapr_foo", key: "dapr_foo", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -3518,7 +3465,6 @@ func TestIsEnvVarAllowed(t *testing.T) {
{name: "keys starting with DAPR_ are denied", key: "DAPR_TEST", want: false},
{name: "APP_API_TOKEN is denied", key: "APP_API_TOKEN", want: false},
{name: "keys with a space are denied", key: "FOO BAR", want: false},
{name: "case insensitive allowlist", key: "foo", want: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -3529,3 +3475,22 @@ func TestIsEnvVarAllowed(t *testing.T) {
}
})
}
func testSecurity(t *testing.T) security.Handler {
secP, err := security.New(context.Background(), security.Options{
TrustAnchors: []byte("test"),
AppID: "test",
ControlPlaneTrustDomain: "test.example.com",
ControlPlaneNamespace: "default",
MTLSEnabled: false,
OverrideCertRequestSource: func(context.Context, []byte) ([]*x509.Certificate, error) {
return []*x509.Certificate{nil}, nil
},
})
require.NoError(t, err)
go secP.Run(context.Background())
sec, err := secP.Handler(context.Background())
require.NoError(t, err)
return sec
}

View File

@ -1,171 +0,0 @@
package security
import (
"context"
"crypto/x509"
"encoding/pem"
"fmt"
"os"
"sync"
"time"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcRetry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
daprCredentials "github.com/dapr/dapr/pkg/credentials"
diag "github.com/dapr/dapr/pkg/diagnostics"
sentryv1pb "github.com/dapr/dapr/pkg/proto/sentry/v1"
securityConsts "github.com/dapr/dapr/pkg/security/consts"
securityToken "github.com/dapr/dapr/pkg/security/token"
)
const (
TLSServerName = "cluster.local"
sentrySignTimeout = time.Second * 5
sentryMaxRetries = 100
)
type Authenticator interface {
GetTrustAnchors() *x509.CertPool
GetCurrentSignedCert() *SignedCertificate
CreateSignedWorkloadCert(id, namespace, trustDomain string) (*SignedCertificate, error)
}
type authenticator struct {
trustAnchors *x509.CertPool
certChainPem []byte
keyPem []byte
genCSRFunc func(id string) ([]byte, []byte, error)
sentryAddress string
currentSignedCert *SignedCertificate
certMutex *sync.RWMutex
}
type SignedCertificate struct {
WorkloadCert []byte
PrivateKeyPem []byte
Expiry time.Time
TrustChain *x509.CertPool
}
func newAuthenticator(sentryAddress string, trustAnchors *x509.CertPool, certChainPem, keyPem []byte, genCSRFunc func(id string) ([]byte, []byte, error)) Authenticator {
return &authenticator{
trustAnchors: trustAnchors,
certChainPem: certChainPem,
keyPem: keyPem,
genCSRFunc: genCSRFunc,
sentryAddress: sentryAddress,
certMutex: &sync.RWMutex{},
}
}
// GetTrustAnchors returns the extracted root cert that serves as the trust anchor.
func (a *authenticator) GetTrustAnchors() *x509.CertPool {
return a.trustAnchors
}
// GetCurrentSignedCert returns the current and latest signed certificate.
func (a *authenticator) GetCurrentSignedCert() *SignedCertificate {
a.certMutex.RLock()
defer a.certMutex.RUnlock()
return a.currentSignedCert
}
// CreateSignedWorkloadCert returns a signed workload certificate, the PEM encoded private key
// And the duration of the signed cert.
func (a *authenticator) CreateSignedWorkloadCert(id, namespace, trustDomain string) (*SignedCertificate, error) {
csrb, pkPem, err := a.genCSRFunc(id)
if err != nil {
return nil, err
}
csrPem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrb})
config, err := daprCredentials.TLSConfigFromCertAndKey(a.certChainPem, a.keyPem, TLSServerName, a.trustAnchors)
if err != nil {
return nil, fmt.Errorf("failed to create tls config from cert and key: %w", err)
}
unaryClientInterceptor := grpcRetry.UnaryClientInterceptor()
if diag.DefaultGRPCMonitoring.IsEnabled() {
unaryClientInterceptor = grpcMiddleware.ChainUnaryClient(
unaryClientInterceptor,
diag.DefaultGRPCMonitoring.UnaryClientInterceptor(),
)
}
conn, err := grpc.Dial(
a.sentryAddress,
grpc.WithTransportCredentials(credentials.NewTLS(config)),
grpc.WithUnaryInterceptor(unaryClientInterceptor))
if err != nil {
diag.DefaultMonitoring.MTLSWorkLoadCertRotationFailed("sentry_conn")
return nil, fmt.Errorf("error establishing connection to sentry: %w", err)
}
defer conn.Close()
c := sentryv1pb.NewCAClient(conn)
token, tokenValidator, err := securityToken.GetSentryToken(true)
if err != nil {
return nil, fmt.Errorf("error obtaining token: %w", err)
}
resp, err := c.SignCertificate(context.Background(),
&sentryv1pb.SignCertificateRequest{
CertificateSigningRequest: csrPem,
Id: getSentryIdentifier(id),
Token: token,
TokenValidator: tokenValidator,
TrustDomain: trustDomain,
Namespace: namespace,
},
grpcRetry.WithMax(sentryMaxRetries),
grpcRetry.WithPerRetryTimeout(sentrySignTimeout),
)
if err != nil {
diag.DefaultMonitoring.MTLSWorkLoadCertRotationFailed("sign")
return nil, fmt.Errorf("error from sentry SignCertificate: %w", err)
}
workloadCert := resp.GetWorkloadCertificate()
validTimestamp := resp.GetValidUntil()
if err = validTimestamp.CheckValid(); err != nil {
diag.DefaultMonitoring.MTLSWorkLoadCertRotationFailed("invalid_ts")
return nil, fmt.Errorf("error parsing ValidUntil: %w", err)
}
expiry := validTimestamp.AsTime()
trustChain := x509.NewCertPool()
for _, c := range resp.GetTrustChainCertificates() {
ok := trustChain.AppendCertsFromPEM(c)
if !ok {
diag.DefaultMonitoring.MTLSWorkLoadCertRotationFailed("chaining")
return nil, fmt.Errorf("failed adding trust chain cert to x509 CertPool: %w", err)
}
}
signedCert := &SignedCertificate{
WorkloadCert: workloadCert,
PrivateKeyPem: pkPem,
Expiry: expiry,
TrustChain: trustChain,
}
a.certMutex.Lock()
defer a.certMutex.Unlock()
a.currentSignedCert = signedCert
return signedCert, nil
}
func getSentryIdentifier(appID string) string {
// Return the injected identity
// Defaults to app ID if not present
localID := os.Getenv(securityConsts.SentryLocalIdentityEnvVar)
if localID != "" {
return localID
}
return appID
}

View File

@ -1,46 +0,0 @@
package security
import (
"crypto/x509"
"testing"
"github.com/stretchr/testify/assert"
securityConsts "github.com/dapr/dapr/pkg/security/consts"
)
func mockGenCSR(id string) ([]byte, []byte, error) {
return []byte{1}, []byte{2}, nil
}
func getTestAuthenticator() Authenticator {
return newAuthenticator("test", x509.NewCertPool(), nil, nil, mockGenCSR)
}
func TestGetTrustAuthAnchors(t *testing.T) {
a := getTestAuthenticator()
ta := a.GetTrustAnchors()
assert.NotNil(t, ta)
}
func TestGetCurrentSignedCert(t *testing.T) {
a := getTestAuthenticator()
a.(*authenticator).currentSignedCert = &SignedCertificate{}
c := a.GetCurrentSignedCert()
assert.NotNil(t, c)
}
func TestGetSentryIdentifier(t *testing.T) {
t.Run("with identity in env", func(t *testing.T) {
envID := "cluster.local"
t.Setenv(securityConsts.SentryLocalIdentityEnvVar, envID)
id := getSentryIdentifier("app1")
assert.Equal(t, envID, id)
})
t.Run("without identity in env", func(t *testing.T) {
id := getSentryIdentifier("app1")
assert.Equal(t, "app1", id)
})
}

View File

@ -1,94 +0,0 @@
package security
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"os"
"github.com/dapr/dapr/pkg/credentials"
diag "github.com/dapr/dapr/pkg/diagnostics"
"github.com/dapr/dapr/pkg/security/consts"
"github.com/dapr/kit/logger"
)
const (
ecPKType = "EC PRIVATE KEY"
)
var log = logger.NewLogger("dapr.runtime.security")
func CertPool(certPem []byte) (*x509.CertPool, error) {
cp := x509.NewCertPool()
ok := cp.AppendCertsFromPEM(certPem)
if !ok {
return nil, errors.New("failed to append PEM root cert to x509 CertPool")
}
return cp, nil
}
func GetCertChain() (*credentials.CertChain, error) {
trustAnchors := os.Getenv(consts.TrustAnchorsEnvVar)
if trustAnchors == "" {
return nil, fmt.Errorf("couldn't find trust anchors in environment variable %s", consts.TrustAnchorsEnvVar)
}
cert := os.Getenv(consts.CertChainEnvVar)
if cert == "" {
return nil, fmt.Errorf("couldn't find cert chain in environment variable %s", consts.CertChainEnvVar)
}
key := os.Getenv(consts.CertKeyEnvVar)
if cert == "" {
return nil, fmt.Errorf("couldn't find cert key in environment variable %s", consts.CertKeyEnvVar)
}
return &credentials.CertChain{
RootCA: []byte(trustAnchors),
Cert: []byte(cert),
Key: []byte(key),
}, nil
}
// GetSidecarAuthenticator returns a new authenticator with the extracted trust anchors.
func GetSidecarAuthenticator(sentryAddress string, certChain *credentials.CertChain) (Authenticator, error) {
trustAnchors, err := CertPool(certChain.RootCA)
if err != nil {
return nil, err
}
log.Info("Trust anchors and cert chain extracted successfully")
return newAuthenticator(sentryAddress, trustAnchors, certChain.Cert, certChain.Key, generateCSRAndPrivateKey), nil
}
func generateCSRAndPrivateKey(id string) ([]byte, []byte, error) {
if id == "" {
return nil, nil, errors.New("id must not be empty")
}
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
diag.DefaultMonitoring.MTLSInitFailed("prikeygen")
return nil, nil, fmt.Errorf("failed to generate private key: %w", err)
}
encodedKey, err := x509.MarshalECPrivateKey(key)
if err != nil {
diag.DefaultMonitoring.MTLSInitFailed("prikeyenc")
return nil, nil, err
}
keyPem := pem.EncodeToMemory(&pem.Block{Type: ecPKType, Bytes: encodedKey})
csr := x509.CertificateRequest{
Subject: pkix.Name{CommonName: id},
DNSNames: []string{id},
}
csrb, err := x509.CreateCertificateRequest(rand.Reader, &csr, key)
if err != nil {
diag.DefaultMonitoring.MTLSInitFailed("csr")
return nil, nil, fmt.Errorf("failed to create sidecar csr: %w", err)
}
return csrb, keyPem, nil
}

View File

@ -1,76 +0,0 @@
package security
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/dapr/dapr/pkg/security/consts"
)
var testRootCert = `-----BEGIN CERTIFICATE-----
MIIBjjCCATOgAwIBAgIQdZeGNuAHZhXSmb37Pnx2QzAKBggqhkjOPQQDAjAYMRYw
FAYDVQQDEw1jbHVzdGVyLmxvY2FsMB4XDTIwMDIwMTAwMzUzNFoXDTMwMDEyOTAw
MzUzNFowGDEWMBQGA1UEAxMNY2x1c3Rlci5sb2NhbDBZMBMGByqGSM49AgEGCCqG
SM49AwEHA0IABAeMFRst4JhcFpebfgEs1MvJdD7h5QkCbLwChRHVEUoaDqd1aYjm
bX5SuNBXz5TBEhHfTV3Objh6LQ2N+CBoCeOjXzBdMA4GA1UdDwEB/wQEAwIBBjAS
BgNVHRMBAf8ECDAGAQH/AgEBMB0GA1UdDgQWBBRBWthv5ZQ3vALl2zXWwAXSmZ+m
qTAYBgNVHREEETAPgg1jbHVzdGVyLmxvY2FsMAoGCCqGSM49BAMCA0kAMEYCIQDN
rQNOck4ENOhmLROE/wqH0MKGjE6P8yzesgnp9fQI3AIhAJaVPrZloxl1dWCgmNWo
Iklq0JnMgJU7nS+VpVvlgBN8
-----END CERTIFICATE-----`
func TestGetTrustAnchors(t *testing.T) {
t.Run("invalid root cert", func(t *testing.T) {
t.Setenv(consts.TrustAnchorsEnvVar, "111")
t.Setenv(consts.CertChainEnvVar, "111")
t.Setenv(consts.CertKeyEnvVar, "111")
certChain, _ := GetCertChain()
caPool, err := CertPool(certChain.Cert)
assert.Error(t, err)
assert.Nil(t, caPool)
})
t.Run("valid root cert", func(t *testing.T) {
t.Setenv(consts.TrustAnchorsEnvVar, testRootCert)
t.Setenv(consts.CertChainEnvVar, "111")
t.Setenv(consts.CertKeyEnvVar, "111")
certChain, err := GetCertChain()
assert.Nil(t, err)
caPool, err := CertPool(certChain.RootCA)
assert.Nil(t, err)
assert.NotNil(t, caPool)
})
}
func TestGenerateSidecarCSR(t *testing.T) {
// can't run this on Windows build agents, GH actions fails with "CryptAcquireContext: Provider DLL failed to initialize correctly."
if runtime.GOOS == "windows" {
return
}
t.Run("empty id", func(t *testing.T) {
_, _, err := generateCSRAndPrivateKey("")
assert.NotNil(t, err)
})
t.Run("with id", func(t *testing.T) {
csr, pk, err := generateCSRAndPrivateKey("test")
assert.Nil(t, err)
assert.True(t, len(csr) > 0)
assert.True(t, len(pk) > 0)
})
}
func TestInitSidecarAuthenticator(t *testing.T) {
t.Setenv(consts.TrustAnchorsEnvVar, testRootCert)
t.Setenv(consts.CertChainEnvVar, "111")
t.Setenv(consts.CertKeyEnvVar, "111")
certChain, _ := GetCertChain()
_, err := GetSidecarAuthenticator("localhost:5050", certChain)
assert.NoError(t, err)
}

View File

@ -13,13 +13,15 @@ limitations under the License.
package consts
/* #nosec */
const (
// APITokenEnvVar is the environment variable for the API token.
//nolint:gosec
APITokenEnvVar = "DAPR_API_TOKEN"
// AppAPITokenEnvVar is the environment variable for the app API token.
//nolint:gosec
AppAPITokenEnvVar = "APP_API_TOKEN"
// APITokenHeader is header name for HTTP/gRPC calls to hold the token.
//nolint:gosec
APITokenHeader = "dapr-api-token"
// TrustBundleK8sSecretName is the name of the kubernetes secret that holds the trust bundle.
@ -41,9 +43,18 @@ const (
// SentryLocalIdentityEnvVar is the environment variable for the local identity sent to Sentry.
SentryLocalIdentityEnvVar = "SENTRY_LOCAL_IDENTITY"
// SentryTokenFileEnvVar is the environment variable for the Sentry token file.
//nolint:gosec
SentryTokenFileEnvVar = "DAPR_SENTRY_TOKEN_FILE"
// AnnotationKeyControlPlane is the annotation to mark a control plane
// component. The value is the name of the control plane service.
AnnotationKeyControlPlane = "dapr.io/control-plane"
// ControlPlaneAddressEnvVar is the daprd environment variable for
// configuring the control plane namespace.
ControlPlaneNamespaceEnvVar = "DAPR_CONTROLPLANE_NAMESPACE"
// ControlPlaneAddressEnvVar is the daprd environment variable for
// configuring the control plane trust domain.
ControlPlaneTrustDomainEnvVar = "DAPR_CONTROLPLANE_TRUST_DOMAIN"
)

View File

@ -28,28 +28,49 @@ import (
)
type Fake struct {
grpcServerOptionFn func() grpc.ServerOption
grpcServerOptionNoClientAuthFn func() grpc.ServerOption
grpcDialOptionFn func(spiffeid.ID) grpc.DialOption
controlPlaneTrustDomainFn func() spiffeid.TrustDomain
controlPlaneNamespaceFn func() string
currentTrustAnchorsFn func() ([]byte, error)
watchTrustAnchorsFn func(context.Context, chan<- []byte)
tlsServerConfigNoClientAuth func() *tls.Config
netListenerIDFn func(net.Listener, spiffeid.ID) net.Listener
netDialerIDFn func(context.Context, spiffeid.ID, time.Duration) func(network, addr string) (net.Conn, error)
tlsServerConfigMTLSFn func(spiffeid.TrustDomain) (*tls.Config, error)
tlsServerConfigNoClientAuthFn func() *tls.Config
tlsServerConfigNoClientAuthOptionFn func(*tls.Config)
netListenerIDFn func(net.Listener, spiffeid.ID) net.Listener
netDialerIDFn func(context.Context, spiffeid.ID, time.Duration) func(network, addr string) (net.Conn, error)
currentTrustAnchorsFn func() ([]byte, error)
watchTrustAnchorsFn func(context.Context, chan<- []byte)
grpcDialOptionFn func(spiffeid.ID) grpc.DialOption
grpcDialOptionUnknownTrustDomainFn func(ns, appID string) grpc.DialOption
grpcServerOptionMTLSFn func() grpc.ServerOption
grpcServerOptionNoClientAuthFn func() grpc.ServerOption
}
func New() *Fake {
return &Fake{
grpcServerOptionFn: func() grpc.ServerOption {
return grpc.Creds(insecure.NewCredentials())
controlPlaneTrustDomainFn: func() spiffeid.TrustDomain {
return spiffeid.RequireTrustDomainFromString("example.org")
},
tlsServerConfigNoClientAuth: func() *tls.Config {
controlPlaneNamespaceFn: func() string {
return "dapr-test"
},
tlsServerConfigMTLSFn: func(spiffeid.TrustDomain) (*tls.Config, error) {
return new(tls.Config), nil
},
tlsServerConfigNoClientAuthFn: func() *tls.Config {
return new(tls.Config)
},
tlsServerConfigNoClientAuthOptionFn: func(*tls.Config) {},
grpcDialOptionFn: func(spiffeid.ID) grpc.DialOption {
return grpc.WithTransportCredentials(insecure.NewCredentials())
},
grpcDialOptionUnknownTrustDomainFn: func(ns, appID string) grpc.DialOption {
return grpc.WithTransportCredentials(insecure.NewCredentials())
},
grpcServerOptionMTLSFn: func() grpc.ServerOption {
return grpc.Creds(nil)
},
grpcServerOptionNoClientAuthFn: func() grpc.ServerOption {
return grpc.Creds(insecure.NewCredentials())
return grpc.Creds(nil)
},
currentTrustAnchorsFn: func() ([]byte, error) {
return []byte{}, nil
@ -57,9 +78,6 @@ func New() *Fake {
watchTrustAnchorsFn: func(context.Context, chan<- []byte) {
return
},
grpcDialOptionFn: func(id spiffeid.ID) grpc.DialOption {
return grpc.WithTransportCredentials(insecure.NewCredentials())
},
netListenerIDFn: func(l net.Listener, _ spiffeid.ID) net.Listener {
return l
},
@ -69,19 +87,81 @@ func New() *Fake {
}
}
func (f *Fake) WithControlPlaneTrustDomainFn(fn func() spiffeid.TrustDomain) *Fake {
f.controlPlaneTrustDomainFn = fn
return f
}
func (f *Fake) WithControlPlaneNamespaceFn(fn func() string) *Fake {
f.controlPlaneNamespaceFn = fn
return f
}
func (f *Fake) WithTLSServerConfigMTLSFn(fn func(spiffeid.TrustDomain) (*tls.Config, error)) *Fake {
f.tlsServerConfigMTLSFn = fn
return f
}
func (f *Fake) WithTLSServerConfigNoClientAuthFn(fn func() *tls.Config) *Fake {
f.tlsServerConfigNoClientAuthFn = fn
return f
}
func (f *Fake) WithTLSServerConfigNoClientAuthOptionFn(fn func(*tls.Config)) *Fake {
f.tlsServerConfigNoClientAuthOptionFn = fn
return f
}
func (f *Fake) WithGRPCDialOptionMTLSFn(fn func(spiffeid.ID) grpc.DialOption) *Fake {
f.grpcDialOptionFn = fn
return f
}
func (f *Fake) WithGRPCDialOptionMTLSUnknownTrustDomainFn(fn func(ns, appID string) grpc.DialOption) *Fake {
f.grpcDialOptionUnknownTrustDomainFn = fn
return f
}
func (f *Fake) WithGRPCServerOptionMTLSFn(fn func() grpc.ServerOption) *Fake {
f.grpcServerOptionMTLSFn = fn
return f
}
func (f *Fake) WithGRPCServerOptionNoClientAuthFn(fn func() grpc.ServerOption) *Fake {
f.grpcServerOptionNoClientAuthFn = fn
return f
}
func (f *Fake) WithGRPCServerOptionFn(fn func() grpc.ServerOption) *Fake {
f.grpcServerOptionFn = fn
return f
func (f *Fake) ControlPlaneTrustDomain() spiffeid.TrustDomain {
return f.controlPlaneTrustDomainFn()
}
func (f *Fake) WithTLSServerConfigNoClientAuthFn(fn func() *tls.Config) *Fake {
f.tlsServerConfigNoClientAuth = fn
return f
func (f *Fake) ControlPlaneNamespace() string {
return f.controlPlaneNamespaceFn()
}
func (f *Fake) TLSServerConfigMTLS(td spiffeid.TrustDomain) (*tls.Config, error) {
return f.tlsServerConfigMTLSFn(td)
}
func (f *Fake) TLSServerConfigNoClientAuth() *tls.Config {
return f.tlsServerConfigNoClientAuthFn()
}
func (f *Fake) TLSServerConfigNoClientAuthOption(cfg *tls.Config) {
f.tlsServerConfigNoClientAuthOptionFn(cfg)
}
func (f *Fake) GRPCDialOptionMTLS(id spiffeid.ID) grpc.DialOption {
return f.grpcDialOptionFn(id)
}
func (f *Fake) GRPCDialOptionMTLSUnknownTrustDomain(ns, appID string) grpc.DialOption {
return f.grpcDialOptionUnknownTrustDomainFn(ns, appID)
}
func (f *Fake) GRPCServerOptionMTLS() grpc.ServerOption {
return f.grpcServerOptionMTLSFn()
}
func (f *Fake) WithCurrentTrustAnchorsFn(fn func() ([]byte, error)) *Fake {
@ -113,14 +193,6 @@ func (f *Fake) GRPCServerOptionNoClientAuth() grpc.ServerOption {
return f.grpcServerOptionNoClientAuthFn()
}
func (f *Fake) GRPCServerOption() grpc.ServerOption {
return f.grpcServerOptionFn()
}
func (f *Fake) TLSServerConfigNoClientAuth() *tls.Config {
return f.tlsServerConfigNoClientAuth()
}
func (f *Fake) CurrentTrustAnchors() ([]byte, error) {
return f.currentTrustAnchorsFn()
}
@ -140,11 +212,3 @@ func (f *Fake) NetListenerID(l net.Listener, id spiffeid.ID) net.Listener {
func (f *Fake) NetDialerID(ctx context.Context, id spiffeid.ID, timeout time.Duration) func(network, addr string) (net.Conn, error) {
return f.netDialerIDFn(ctx, id, timeout)
}
func (f *Fake) ControlPlaneNamespace() string {
return "dapr-test"
}
func (f *Fake) ControlPlaneTrustDomain() spiffeid.TrustDomain {
return spiffeid.RequireTrustDomainFromString("example.com")
}

View File

@ -17,10 +17,14 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"os"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/dapr/dapr/pkg/security/consts"
)
// NewServer returns a `tls.Config` intended for network servers. Because
@ -31,37 +35,7 @@ import (
// TODO: @joshvanl: This package should be removed in v1.13.
func NewServer(svid x509svid.Source, bundle x509bundle.Source, authorizer tlsconfig.Authorizer) *tls.Config {
spiffeVerify := tlsconfig.VerifyPeerCertificate(bundle, authorizer)
dnsVerify := func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
var certs []*x509.Certificate
for _, rawCert := range rawCerts {
cert, err := x509.ParseCertificate(rawCert)
if err != nil {
return err
}
certs = append(certs, cert)
}
if len(certs) == 0 {
return errors.New("no certificates provided")
}
id, err := svid.GetX509SVID()
if err != nil {
return err
}
rootBundle, err := bundle.GetX509BundleForTrustDomain(id.ID.TrustDomain())
if err != nil {
return err
}
_, err = certs[0].Verify(x509.VerifyOptions{
DNSName: "cluster.local",
Intermediates: newCertPool(certs[1:]),
Roots: newCertPool(rootBundle.X509Authorities()),
})
return err
}
dnsVerify := dnsVerifyFn(svid, bundle)
return &tls.Config{
ClientAuth: tls.RequireAnyClientCert,
@ -82,6 +56,72 @@ func NewServer(svid x509svid.Source, bundle x509bundle.Source, authorizer tlscon
}
}
// NewDialClient returns a `tls.Config` intended for network clients. Because pre
// v1.12 Dapr servers will be using the issuing CA key pair (!!) for serving
// and client auth, we need to fallback the `VerifyPeerCertificate` method to
// match on `cluster.local` DNS if and when the SPIFFE mTLS handshake fails.
// TODO: @joshvanl: This package should be removed in v1.13.
func NewDialClient(svid x509svid.Source, bundle x509bundle.Source, authorizer tlsconfig.Authorizer) *tls.Config {
tlsConfig := newDialClientNoClientAuth(svid, bundle, authorizer)
tlsConfig.GetClientCertificate = tlsconfig.GetClientCertificate(svid)
return tlsConfig
}
// NewDialClientOptionalClientAuth returns a `tls.Config` intended for network
// clients with optional client authentication. Because pre v1.12 Dapr servers
// will be using the issuing CA key pair (!!) for serving and client auth, we
// need to fallback the `VerifyPeerCertificate` method to match on
// `cluster.local` DNS if and when the SPIFFE mTLS handshake fails.
// Sets the client certificate to that configured in environment variables to satisfy
// sentry v1.11 servers.
func NewDialClientOptionalClientAuth(svid x509svid.Source, bundle x509bundle.Source, authorizer tlsconfig.Authorizer) (*tls.Config, error) {
tlsConfig := newDialClientNoClientAuth(svid, bundle, authorizer)
certPEM, cok := os.LookupEnv(consts.CertChainEnvVar)
keyPEM, pok := os.LookupEnv(consts.CertKeyEnvVar)
if cok && pok {
cert, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
if err != nil {
return nil, err
}
tlsConfig.GetClientCertificate = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &cert, nil
}
}
return tlsConfig, nil
}
// newDialClientNoClientAuth returns a `tls.Config` intended for network clients
// without client authentication.
func newDialClientNoClientAuth(svid x509svid.Source, bundle x509bundle.Source, authorizer tlsconfig.Authorizer) *tls.Config {
spiffeVerify := tlsconfig.VerifyPeerCertificate(bundle, authorizer)
dnsVerify := dnsVerifyFn(svid, bundle)
return &tls.Config{
MinVersion: tls.VersionTLS12,
// Yep! We need to set this option because we are performing our own TLS
// handshake verification, namely the SPIFFE ID validation, and then
// falling back to the DNS verification `cluster.local`. See:
// https://pkg.go.dev/crypto/tls#Config
// This is not insecure (bar the poor DNS verification which is needed for
// backwards compatibility).
InsecureSkipVerify: true, //nolint:gosec
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
// If SPIFFE verification fails, also attempt `cluster.local` DNS
// verification.
sErr := spiffeVerify(rawCerts, nil)
if sErr != nil {
dErr := dnsVerify(rawCerts, nil)
if dErr != nil {
return errors.Join(sErr, dErr)
}
}
return nil
},
}
}
func newCertPool(certs []*x509.Certificate) *x509.CertPool {
pool := x509.NewCertPool()
for _, cert := range certs {
@ -89,3 +129,42 @@ func newCertPool(certs []*x509.Certificate) *x509.CertPool {
}
return pool
}
func dnsVerifyFn(svid x509svid.Source, bundle x509bundle.Source) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
var certs []*x509.Certificate
for _, rawCert := range rawCerts {
cert, err := x509.ParseCertificate(rawCert)
if err != nil {
return err
}
certs = append(certs, cert)
}
if len(certs) == 0 {
return errors.New("no certificates provided")
}
id, err := svid.GetX509SVID()
if err != nil {
return err
}
// Default to empty trust domain if no SPIFFE ID is present.
td := spiffeid.TrustDomain{}
if id != nil {
td = id.ID.TrustDomain()
}
rootBundle, err := bundle.GetX509BundleForTrustDomain(td)
if err != nil {
return err
}
_, err = certs[0].Verify(x509.VerifyOptions{
DNSName: "cluster.local",
Intermediates: newCertPool(certs[1:]),
Roots: newCertPool(rootBundle.X509Authorities()),
})
return err
}
}

View File

@ -37,66 +37,66 @@ import (
"github.com/stretchr/testify/require"
)
func serialNumber(t *testing.T) *big.Int {
t.Helper()
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)
return serialNumber
}
func genIssuerCA(t *testing.T) (issuerCA *x509.Certificate, issuerKey *ecdsa.PrivateKey, rootCA x509bundle.Source, rootPool *x509.CertPool) {
t.Helper()
rootPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
tmpl := x509.Certificate{
SerialNumber: serialNumber(t),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
SignatureAlgorithm: x509.ECDSAWithSHA256,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &rootPK.PublicKey, rootPK)
require.NoError(t, err)
rootCert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
issPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
tmpl = x509.Certificate{
SerialNumber: serialNumber(t),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
SignatureAlgorithm: x509.ECDSAWithSHA256,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
issCertDER, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &issPK.PublicKey, rootPK)
require.NoError(t, err)
issCert, err := x509.ParseCertificate(issCertDER)
require.NoError(t, err)
rootPool = x509.NewCertPool()
rootPool.AddCert(rootCert)
return issCert, issPK,
x509bundle.FromX509Authorities(spiffeid.RequireTrustDomainFromString("example.com"), []*x509.Certificate{rootCert}),
rootPool
}
func Test_NewServer(t *testing.T) {
serialNumber := func(t *testing.T) *big.Int {
t.Helper()
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)
return serialNumber
}
genIssuerCA := func(t *testing.T) (issuerCA *x509.Certificate, issuerKey *ecdsa.PrivateKey, rootCA x509bundle.Source, rootPool *x509.CertPool) {
t.Helper()
rootPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
tmpl := x509.Certificate{
SerialNumber: serialNumber(t),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
SignatureAlgorithm: x509.ECDSAWithSHA256,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &rootPK.PublicKey, rootPK)
require.NoError(t, err)
rootCert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
issPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
tmpl = x509.Certificate{
SerialNumber: serialNumber(t),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
SignatureAlgorithm: x509.ECDSAWithSHA256,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
issCertDER, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &issPK.PublicKey, rootPK)
require.NoError(t, err)
issCert, err := x509.ParseCertificate(issCertDER)
require.NoError(t, err)
rootPool = x509.NewCertPool()
rootPool.AddCert(rootCert)
return issCert, issPK,
x509bundle.FromX509Authorities(spiffeid.RequireTrustDomainFromString("example.com"), []*x509.Certificate{rootCert}),
rootPool
}
issCert, issKey, rootCA, rootPool := genIssuerCA(t)
diffCert, diffKey, diffRootCA, diffPool := genIssuerCA(t)
@ -129,7 +129,7 @@ func Test_NewServer(t *testing.T) {
var lock sync.Mutex
server := &http.Server{
Addr: ":0",
Addr: "localhost:0",
TLSConfig: NewServer(serverSVID, rootCA, tlsconfig.AuthorizeAny()),
ReadHeaderTimeout: time.Second,
}
@ -293,3 +293,199 @@ func Test_NewServer(t *testing.T) {
assert.ErrorContains(t, err, "remote error: tls: bad certificate")
})
}
func Test_NewDialClient(t *testing.T) {
issCert, issKey, rootCA, rootPool := genIssuerCA(t)
diffCert, diffKey, diffRootCA, diffPool := genIssuerCA(t)
clientPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
clientPKDER, err := x509.MarshalPKCS8PrivateKey(clientPK)
require.NoError(t, err)
serverPK, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
serverPKDER, err := x509.MarshalPKCS8PrivateKey(serverPK)
require.NoError(t, err)
clientCertDER, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
URIs: []*url.URL{spiffeid.RequireFromSegments(spiffeid.RequireTrustDomainFromString("example.com"), "client").URL()},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
SerialNumber: serialNumber(t),
}, issCert, &clientPK.PublicKey, issKey)
require.NoError(t, err)
clientSVID, err := x509svid.ParseRaw(append(clientCertDER, issCert.Raw...), clientPKDER)
require.NoError(t, err)
serve := func(t *testing.T, serverConfig *tls.Config) error {
var lis net.Listener
var lock sync.Mutex
server := &http.Server{
Addr: "localhost:0",
TLSConfig: serverConfig,
ReadHeaderTimeout: time.Second,
}
server.BaseContext = func(nlis net.Listener) context.Context {
lock.Lock()
defer lock.Unlock()
lis = nlis
return context.Background()
}
serverClosed := make(chan struct{})
go func() {
defer close(serverClosed)
require.ErrorIs(t, server.ListenAndServeTLS("", ""), http.ErrServerClosed)
}()
t.Cleanup(func() {
require.NoError(t, server.Close())
select {
case <-serverClosed:
case <-time.After(time.Second):
t.Fatal("timed out waiting for server to close")
}
})
assert.Eventually(t, func() bool {
lock.Lock()
defer lock.Unlock()
return lis != nil
}, time.Second, time.Millisecond)
client := &http.Client{Transport: &http.Transport{TLSClientConfig: NewDialClient(clientSVID, rootCA, tlsconfig.AuthorizeAny())}}
conn, err := client.Get(fmt.Sprintf("https://localhost:%d/", lis.Addr().(*net.TCPAddr).Port))
if err != nil {
return err
}
conn.Body.Close()
return nil
}
t.Run("if server uses a SVID in the same root with the correct Trust Domain, no error", func(t *testing.T) {
id := spiffeid.RequireFromSegments(spiffeid.RequireTrustDomainFromString("example.com"), "server")
serverCertDER, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
URIs: []*url.URL{id.URL()},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
SerialNumber: serialNumber(t),
}, issCert, &serverPK.PublicKey, issKey)
require.NoError(t, err)
serversvid, err := x509svid.ParseRaw(append(serverCertDER, issCert.Raw...), serverPKDER)
require.NoError(t, err)
assert.NoError(t, serve(t, tlsconfig.MTLSServerConfig(serversvid, rootCA, tlsconfig.AuthorizeAny())))
})
t.Run("if server uses a SVID but signed by different root, expect error", func(t *testing.T) {
id := spiffeid.RequireFromSegments(spiffeid.RequireTrustDomainFromString("example.com"), "server")
serverCertDER, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
URIs: []*url.URL{id.URL()},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
SerialNumber: serialNumber(t),
}, diffCert, &serverPK.PublicKey, diffKey)
require.NoError(t, err)
serversvid, err := x509svid.ParseRaw(append(serverCertDER, diffCert.Raw...), serverPKDER)
require.NoError(t, err)
err = serve(t, tlsconfig.MTLSServerConfig(serversvid, diffRootCA, tlsconfig.AuthorizeAny()))
assert.ErrorContains(t, err, "x509: ECDSA verification failure")
})
t.Run("if server uses DNS and is `cluster.local`, expect no error", func(t *testing.T) {
serverCertDER, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
DNSNames: []string{"cluster.local"},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
SerialNumber: serialNumber(t),
}, issCert, &serverPK.PublicKey, issKey)
require.NoError(t, err)
assert.NoError(t, serve(t, &tls.Config{
RootCAs: rootPool,
InsecureSkipVerify: true, //nolint: gosec // this is a test
Certificates: []tls.Certificate{
{Certificate: [][]byte{serverCertDER, issCert.Raw}, PrivateKey: serverPK},
},
}))
})
t.Run("if server uses DNS and one is `cluster.local`, expect no error", func(t *testing.T) {
serverCertDER, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
DNSNames: []string{"no-cluster.foo.local", "cluster.local"},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
SerialNumber: serialNumber(t),
}, issCert, &serverPK.PublicKey, issKey)
require.NoError(t, err)
assert.NoError(t, serve(t, &tls.Config{
RootCAs: rootPool,
InsecureSkipVerify: true, //nolint: gosec // this is a test
Certificates: []tls.Certificate{
{Certificate: [][]byte{serverCertDER, issCert.Raw}, PrivateKey: serverPK},
},
}))
})
t.Run("if server uses DNS but none are `cluster.local`, expect error", func(t *testing.T) {
serverCertDER, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
DNSNames: []string{"no-cluster.foo.local", "local.cluster"},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
SerialNumber: serialNumber(t),
}, issCert, &serverPK.PublicKey, issKey)
require.NoError(t, err)
err = serve(t, &tls.Config{
RootCAs: rootPool,
InsecureSkipVerify: true, //nolint: gosec // this is a test
Certificates: []tls.Certificate{
{Certificate: [][]byte{serverCertDER, issCert.Raw}, PrivateKey: serverPK},
},
})
assert.ErrorContains(t, err, "x509svid: could not get leaf SPIFFE ID: certificate contains no URI SAN\nx509: certificate is valid for no-cluster.foo.local, local.cluster, not cluster.local")
})
t.Run("if server uses DNS but is from different root, expect error", func(t *testing.T) {
serverCertDER, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{
DNSNames: []string{"no-cluster.foo.local", "local.cluster"},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute),
SerialNumber: serialNumber(t),
}, diffCert, &serverPK.PublicKey, diffKey)
require.NoError(t, err)
err = serve(t, &tls.Config{
RootCAs: diffPool,
InsecureSkipVerify: true, //nolint: gosec // this is a test
Certificates: []tls.Certificate{
{Certificate: [][]byte{serverCertDER, diffCert.Raw}, PrivateKey: serverPK},
},
})
assert.ErrorContains(t, err, "x509svid: could not get leaf SPIFFE ID: certificate contains no URI SAN\nx509: certificate is valid for no-cluster.foo.local, local.cluster, not cluster.local")
})
}

View File

@ -44,25 +44,28 @@ var log = logger.NewLogger("dapr.runtime.security")
type RequestFn func(ctx context.Context, der []byte) ([]*x509.Certificate, error)
// Handler implements middleware for client and server connection security.
//
//nolint:interfacebloat
type Handler interface {
GRPCServerOption() grpc.ServerOption
GRPCServerOptionMTLS() grpc.ServerOption
GRPCServerOptionNoClientAuth() grpc.ServerOption
GRPCDialOption(spiffeid.ID) grpc.DialOption
GRPCDialOptionMTLSUnknownTrustDomain(ns, appID string) grpc.DialOption
GRPCDialOptionMTLS(spiffeid.ID) grpc.DialOption
TLSServerConfigNoClientAuth() *tls.Config
NetListenerID(net.Listener, spiffeid.ID) net.Listener
NetDialerID(context.Context, spiffeid.ID, time.Duration) func(network, addr string) (net.Conn, error)
CurrentTrustAnchors() ([]byte, error)
WatchTrustAnchors(context.Context, chan<- []byte)
ControlPlaneNamespace() string
ControlPlaneTrustDomain() spiffeid.TrustDomain
ControlPlaneNamespace() string
CurrentTrustAnchors() ([]byte, error)
WatchTrustAnchors(context.Context, chan<- []byte)
}
// Provider is the security provider.
type Provider interface {
Start(context.Context) error
Run(context.Context) error
Handler(context.Context) (Handler, error)
}
@ -133,7 +136,6 @@ func New(ctx context.Context, opts Options) (Provider, error) {
}
var source *x509source
if opts.MTLSEnabled {
if len(opts.TrustAnchors) > 0 && len(opts.TrustAnchorsFile) > 0 {
return nil, errors.New("trust anchors cannot be specified in both TrustAnchors and TrustAnchorsFile")
@ -164,9 +166,9 @@ func New(ctx context.Context, opts Options) (Provider, error) {
}, nil
}
// Start is a blocking function which starts the security provider, handling
// Run is a blocking function which starts the security provider, handling
// rotation of credentials.
func (p *provider) Start(ctx context.Context) error {
func (p *provider) Run(ctx context.Context) error {
if !p.running.CompareAndSwap(false, true) {
return errors.New("security provider already started")
}
@ -238,17 +240,28 @@ func (p *provider) Handler(ctx context.Context) (Handler, error) {
}
}
// GRPCServerOption returns a gRPC server option which instruments
// GRPCDialOptionMTLS returns a gRPC dial option which instruments client
// authentication using the current signed client certificate.
func (s *security) GRPCDialOptionMTLS(appID spiffeid.ID) grpc.DialOption {
if !s.mtls {
return grpc.WithTransportCredentials(insecure.NewCredentials())
}
return grpc.WithTransportCredentials(credentials.NewTLS(
legacy.NewDialClient(s.source, s.source, tlsconfig.AuthorizeID(appID)),
))
}
// GRPCServerOptionMTLS returns a gRPC server option which instruments
// authentication of clients using the current trust anchors.
func (s *security) GRPCServerOption() grpc.ServerOption {
func (s *security) GRPCServerOptionMTLS() grpc.ServerOption {
if !s.mtls {
return grpc.Creds(insecure.NewCredentials())
}
// TODO: It would be better if we could give a subset of trust domains in
// which this server authorizes.
return grpc.Creds(
credentials.NewTLS(legacy.NewServer(s.source, s.source, tlsconfig.AuthorizeAny())),
// TODO: It would be better if we could give a subset of trust domains in
// which this server authorizes.
grpccredentials.MTLSServerCredentials(s.source, s.source, tlsconfig.AuthorizeAny()),
)
}
@ -256,9 +269,54 @@ func (s *security) GRPCServerOption() grpc.ServerOption {
// authentication of clients using the current trust anchors. Doesn't require
// clients to present a certificate.
func (s *security) GRPCServerOptionNoClientAuth() grpc.ServerOption {
return grpc.Creds(
grpccredentials.TLSServerCredentials(s.source),
)
return grpc.Creds(grpccredentials.TLSServerCredentials(s.source))
}
// GRPCDialOptionMTLSUnknownTrustDomain returns a gRPC dial option which
// instruments client authentication using the current signed client
// certificate. Doesn't verify the servers trust domain, but does authorize the
// SPIFFE ID path.
// Used for clients which don't know the servers Trust Domain.
func (s *security) GRPCDialOptionMTLSUnknownTrustDomain(ns, appID string) grpc.DialOption {
if !s.mtls {
return grpc.WithTransportCredentials(insecure.NewCredentials())
}
expID := "/ns/" + ns + "/" + appID
matcher := func(actual spiffeid.ID) error {
if actual.Path() != expID {
return fmt.Errorf("unexpected SPIFFE ID: %q", actual)
}
return nil
}
return grpc.WithTransportCredentials(credentials.NewTLS(
legacy.NewDialClient(s.source, s.source, tlsconfig.AdaptMatcher(matcher)),
))
}
// CurrentTrustAnchors returns the current trust anchors for this Dapr
// installation.
func (s *security) CurrentTrustAnchors() ([]byte, error) {
if s.source == nil {
return nil, nil
}
ta, err := s.source.trustAnchors.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal trust anchors: %w", err)
}
return ta, nil
}
// ControlPlaneTrustDomain returns the trust domain of the control plane.
func (s *security) ControlPlaneTrustDomain() spiffeid.TrustDomain {
return s.controlPlaneTrustDomain
}
// ControlPlaneNamespace returns the dapr namespace of the control plane.
func (s *security) ControlPlaneNamespace() string {
return s.controlPlaneNamespace
}
// WatchTrustAnchors watches for changes to the trust domains and returns the
@ -296,23 +354,6 @@ func (s *security) TLSServerConfigNoClientAuth() *tls.Config {
return tlsconfig.TLSServerConfig(s.source)
}
// CurrentTrustAnchors returns the current trust anchors for this Dapr
// installation.
func (s *security) CurrentTrustAnchors() ([]byte, error) {
return s.source.trustAnchors.Marshal()
}
// GRPCDialOption returns a gRPC dial option which instruments client
// authentication using the current signed client certificate.
func (s *security) GRPCDialOption(appID spiffeid.ID) grpc.DialOption {
if !s.mtls {
return grpc.WithTransportCredentials(insecure.NewCredentials())
}
return grpc.WithTransportCredentials(
grpccredentials.MTLSClientCredentials(s.source, s.source, tlsconfig.AuthorizeID(appID)),
)
}
// NetListenerID returns a mTLS net listener which instruments using the
// current signed server certificate. Authorizes client matches against the
// given SPIFFE ID.
@ -338,14 +379,6 @@ func (s *security) NetDialerID(ctx context.Context, spiffeID spiffeid.ID, timeou
}).Dial
}
func (s *security) ControlPlaneTrustDomain() spiffeid.TrustDomain {
return s.controlPlaneTrustDomain
}
func (s *security) ControlPlaneNamespace() string {
return s.controlPlaneNamespace
}
// CurrentNamespace returns the namespace of this workload.
func CurrentNamespace() string {
namespace, ok := os.LookupEnv("NAMESPACE")
@ -356,8 +389,8 @@ func CurrentNamespace() string {
}
// SentryID returns the SPIFFE ID of the sentry server.
func SentryID(td spiffeid.TrustDomain, sentryNamespace string) (spiffeid.ID, error) {
sentryID, err := spiffeid.FromSegments(td, "ns", sentryNamespace, "dapr-sentry")
func SentryID(sentryTrustDomain spiffeid.TrustDomain, sentryNamespace string) (spiffeid.ID, error) {
sentryID, err := spiffeid.FromSegments(sentryTrustDomain, "ns", sentryNamespace, "dapr-sentry")
if err != nil {
return spiffeid.ID{}, fmt.Errorf("failed to parse sentry SPIFFE ID: %w", err)
}

View File

@ -106,7 +106,7 @@ func Test_Start(t *testing.T) {
providerStopped := make(chan struct{})
go func() {
defer close(providerStopped)
require.NoError(t, p.Start(ctx))
require.NoError(t, p.Run(ctx))
}()
prov := p.(*provider)

View File

@ -15,7 +15,9 @@ package spiffe
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"github.com/spiffe/go-spiffe/v2/spiffegrpc/grpccredentials"
@ -24,42 +26,77 @@ import (
// Parsed is a parsed SPIFFE ID according to the Dapr SPIFFE ID path format.
type Parsed struct {
TrustDomain string
Namespace string
AppID string
id spiffeid.ID
namespace string
appID string
}
// FromContext parses a SPIFFE ID from a gRPC context.
func FromContext(ctx context.Context) (Parsed, bool, error) {
// FromGRPCContext parses a SPIFFE ID from a gRPC context.
func FromGRPCContext(ctx context.Context) (*Parsed, bool, error) {
// Apply access control list filter
id, ok := grpccredentials.PeerIDFromContext(ctx)
if !ok {
return Parsed{}, false, nil
return nil, false, nil
}
return fromID(id)
}
func fromID(id spiffeid.ID) (Parsed, bool, error) {
split := strings.Split(id.Path(), "/")
// Don't force match of 4 segments, since we may want to add more identifiers
// to the path in future which would otherwise break backwards compat.
if len(split) < 4 || split[0] != "" || split[1] != "ns" {
return Parsed{}, false, fmt.Errorf("malformed SPIFFE ID: %s", id.String())
return nil, false, fmt.Errorf("malformed SPIFFE ID: %s", id.String())
}
return Parsed{
TrustDomain: id.TrustDomain().String(),
Namespace: split[2],
AppID: split[3],
return &Parsed{
id: id,
namespace: split[2],
appID: split[3],
}, true, nil
}
func (p Parsed) ToID() (spiffeid.ID, error) {
td, err := spiffeid.TrustDomainFromString(p.TrustDomain)
if err != nil {
return spiffeid.ID{}, err
// FromStrings builds a Dapr SPIFFE ID with the given namespace and app ID in
// the given Trust Domain.
func FromStrings(td spiffeid.TrustDomain, namespace, appID string) (*Parsed, error) {
if len(td.String()) == 0 || len(namespace) == 0 || len(appID) == 0 {
return nil, errors.New("malformed SPIFFE ID")
}
return spiffeid.FromSegments(td, "ns", p.Namespace, p.AppID)
id, err := spiffeid.FromSegments(td, "ns", namespace, appID)
if err != nil {
return nil, err
}
return &Parsed{
id: id,
namespace: namespace,
appID: appID,
}, nil
}
func (p *Parsed) TrustDomain() spiffeid.TrustDomain {
if p == nil {
return spiffeid.TrustDomain{}
}
return p.id.TrustDomain()
}
func (p *Parsed) AppID() string {
if p == nil {
return ""
}
return p.appID
}
func (p *Parsed) Namespace() string {
if p == nil {
return ""
}
return p.namespace
}
func (p *Parsed) URL() *url.URL {
if p == nil {
return new(url.URL)
}
return p.id.URL()
}

View File

@ -15,93 +15,51 @@ package spiffe
import (
"errors"
"fmt"
"testing"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/assert"
)
func TestFromID(t *testing.T) {
func TestFromStrings(t *testing.T) {
tests := map[string]struct {
id spiffeid.ID
expOK bool
td spiffeid.TrustDomain
appID string
ns string
expErr error
expID Parsed
expID *Parsed
}{
"valid SPIFFE ID": {
id: spiffeid.RequireFromSegments(spiffeid.RequireTrustDomainFromString("example.org"), "ns", "test", "app"),
expOK: true,
expID: Parsed{
TrustDomain: "example.org",
Namespace: "test",
AppID: "app",
},
},
"valid SPIFFE ID with extra identifiers": {
id: spiffeid.RequireFromSegments(spiffeid.RequireTrustDomainFromString("example.org"), "ns", "test", "app", "extra", "identifiers"),
expOK: true,
expID: Parsed{
TrustDomain: "example.org",
Namespace: "test",
AppID: "app",
td: spiffeid.RequireTrustDomainFromString("example.org"),
ns: "test",
appID: "app",
expID: &Parsed{
id: spiffeid.RequireFromString("spiffe://example.org/ns/test/app"),
namespace: "test",
appID: "app",
},
},
"SPIFFE ID: no namespace": {
id: spiffeid.RequireFromSegments(spiffeid.RequireTrustDomainFromString("example.org"), "test", "bar", "app"),
expOK: false,
expErr: errors.New("malformed SPIFFE ID: spiffe://example.org/test/bar/app"),
expID: Parsed{},
td: spiffeid.RequireTrustDomainFromString("example.org"),
ns: "",
appID: "app",
expErr: errors.New("malformed SPIFFE ID"),
expID: nil,
},
"SPIFFE ID: too short": {
id: spiffeid.RequireFromPath(spiffeid.RequireTrustDomainFromString("example.org"), "/a"),
expOK: false,
expErr: errors.New("malformed SPIFFE ID: spiffe://example.org/a"),
expID: Parsed{},
"SPIFFE ID: no app ID": {
td: spiffeid.RequireTrustDomainFromString("example.org"),
ns: "test",
appID: "",
expErr: errors.New("malformed SPIFFE ID"),
expID: nil,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
id, ok, err := fromID(test.id)
assert.Equal(t, test.expOK, ok)
id, err := FromStrings(test.td, test.ns, test.appID)
assert.Equal(t, test.expErr, err)
assert.Equal(t, test.expID, id)
})
}
}
func TestToID(t *testing.T) {
tests := map[string]struct {
parsed Parsed
expID spiffeid.ID
expErr error
}{
"valid parsed SPIFFE ID": {
parsed: Parsed{
TrustDomain: "example.org",
Namespace: "test",
AppID: "app",
},
expID: spiffeid.RequireFromSegments(spiffeid.RequireTrustDomainFromString("example.org"), "ns", "test", "app"),
},
"invalid trust domain": {
parsed: Parsed{
TrustDomain: "invalid^&%$^%$",
Namespace: "test",
AppID: "app",
},
expErr: fmt.Errorf("trust domain characters are limited to lowercase letters, numbers, dots, dashes, and underscores"),
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
id, err := test.parsed.ToID()
assert.Equal(t, test.expID, id)
assert.Equal(t, test.expErr, err)
})
}
}

View File

@ -33,15 +33,16 @@ import (
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
retry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffegrpc/grpccredentials"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"k8s.io/utils/clock"
"github.com/dapr/dapr/pkg/diagnostics"
sentryv1pb "github.com/dapr/dapr/pkg/proto/sentry/v1"
"github.com/dapr/dapr/pkg/security/legacy"
secpem "github.com/dapr/dapr/pkg/security/pem"
sentryToken "github.com/dapr/dapr/pkg/security/token"
)
@ -281,11 +282,14 @@ func (x *x509source) requestFromSentry(ctx context.Context, csrDER []byte) ([]*x
)
}
tlsConfig, err := legacy.NewDialClientOptionalClientAuth(x, x, tlsconfig.AuthorizeID(x.sentryID))
if err != nil {
return nil, fmt.Errorf("error creating tls config: %w", err)
}
conn, err := grpc.DialContext(ctx,
x.sentryAddress,
grpc.WithTransportCredentials(
grpccredentials.TLSClientCredentials(x, tlsconfig.AuthorizeID(x.sentryID)),
),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
grpc.WithUnaryInterceptor(unaryClientInterceptor),
grpc.WithReturnConnectionError(),
)
@ -306,7 +310,7 @@ func (x *x509source) requestFromSentry(ctx context.Context, csrDER []byte) ([]*x
CertificateSigningRequest: pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST", Bytes: csrDER,
}),
Id: x.appID,
Id: getSentryIdentifier(x.appID),
Token: token,
Namespace: x.appNamespace,
TokenValidator: tokenValidator,
@ -314,6 +318,13 @@ func (x *x509source) requestFromSentry(ctx context.Context, csrDER []byte) ([]*x
if x.trustDomain != nil {
req.TrustDomain = *x.trustDomain
} else {
// For v1.11 sentry, if the trust domain is empty in the request then it
// will return an empty certificate so we default to `public` here to
// ensure we get an identity certificate back.
// This request field is ignored for non control-plane requests in v1.12.
// TODO: @joshvanl: Remove in v1.13.
req.TrustDomain = "public"
}
resp, err := sentryv1pb.NewCAClient(conn).SignCertificate(ctx, req)
@ -475,3 +486,12 @@ func isControlPlaneService(id string) bool {
return false
}
}
func getSentryIdentifier(appID string) string {
// return injected identity, default id if not present
localID := os.Getenv("SENTRY_LOCAL_IDENTITY")
if localID != "" {
return localID
}
return appID
}

View File

@ -119,7 +119,7 @@ func (s *sentry) Start(parentCtx context.Context) error {
// Start all background processes
runners := concurrency.NewRunnerManager(
provider.Start,
provider.Run,
func(ctx context.Context) error {
sec, secErr := provider.Handler(ctx)
if secErr != nil {

View File

@ -20,6 +20,7 @@ import (
"crypto/x509"
"time"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
@ -138,11 +139,12 @@ func New(ctx context.Context, conf config.Config) (Signer, error) {
}
func (c *ca) SignIdentity(ctx context.Context, req *SignRequest, overrideDuration bool) ([]*x509.Certificate, error) {
spiffeID, err := (spiffe.Parsed{
TrustDomain: req.TrustDomain,
Namespace: req.Namespace,
AppID: req.AppID,
}).ToID()
td, err := spiffeid.TrustDomainFromString(req.TrustDomain)
if err != nil {
return nil, err
}
spiffeID, err := spiffe.FromStrings(td, req.Namespace, req.AppID)
if err != nil {
return nil, err
}

View File

@ -86,17 +86,18 @@ func generateIssuerCert(trustDomain string, skew time.Duration) (*x509.Certifica
return nil, err
}
sentryID, err := (spiffe.Parsed{
TrustDomain: trustDomain,
Namespace: security.CurrentNamespace(),
AppID: "dapr-sentry",
}).ToID()
td, err := spiffeid.TrustDomainFromString(trustDomain)
if err != nil {
return nil, err
}
sentryID, err := spiffe.FromStrings(td, security.CurrentNamespace(), "dapr-sentry")
if err != nil {
return nil, fmt.Errorf("failed to generate sentry ID: %w", err)
}
cert.KeyUsage |= x509.KeyUsageCertSign | x509.KeyUsageCRLSign
cert.Subject = pkix.Name{Organization: []string{sentryID.String()}}
cert.Subject = pkix.Name{Organization: []string{sentryID.URL().String()}}
cert.IsCA = true
cert.BasicConstraintsValid = true
cert.SignatureAlgorithm = x509.ECDSAWithSHA256
@ -109,7 +110,7 @@ func generateIssuerCert(trustDomain string, skew time.Duration) (*x509.Certifica
}
// generateWorkloadCert returns a CA issuing x509 Certificate.
func generateWorkloadCert(sig x509.SignatureAlgorithm, ttl, skew time.Duration, id spiffeid.ID) (*x509.Certificate, error) {
func generateWorkloadCert(sig x509.SignatureAlgorithm, ttl, skew time.Duration, id *spiffe.Parsed) (*x509.Certificate, error) {
cert, err := generateBaseCert(ttl, skew)
if err != nil {
return nil, err

View File

@ -151,10 +151,10 @@ func (s *server) signCertificate(ctx context.Context, req *sentryv1pb.SignCertif
}
// TODO: @joshvanl: before v1.12, daprd was matching on
// `<app-id>.<namespace>.svc.cluster.local`/`cluster.local` DNS SAN name so
// without this, daprd->daprd connections would fail. This is no longer the
// case since we now match with SPIFFE URI SAN, but we need to keep this here
// for backwards compatibility. Remove after v1.14.
// `<app-id>.<namespace>.svc.cluster.local` DNS SAN name so without this,
// daprd->daprd connections would fail. This is no longer the case since we
// now match with SPIFFE URI SAN, but we need to keep this here for backwards
// compatibility. Remove after v1.14.
var dns []string
switch {
case req.Namespace == security.CurrentNamespace() && req.Id == "dapr-injector":

View File

@ -28,7 +28,7 @@ type FailingAppChannel struct {
KeyFunc func(req *invokev1.InvokeMethodRequest) string
}
func (f *FailingAppChannel) GetAppConfig(appID string) (*config.ApplicationConfig, error) {
func (f *FailingAppChannel) GetAppConfig(_ context.Context, appID string) (*config.ApplicationConfig, error) {
return nil, nil
}

View File

@ -125,7 +125,7 @@ var allowListsForServiceInvocationTests = []struct {
"opDeny",
"allowlists-callee-http",
"opDeny",
"failed to invoke, id: allowlists-callee-http, err: rpc error: code = PermissionDenied desc = access control policy has denied access to appid: allowlists-caller operation: opDeny verb: POST",
"failed to invoke, id: allowlists-callee-http, err: rpc error: code = PermissionDenied desc = access control policy has denied access to id: spiffe://public/ns/dapr-tests/allowlists-caller operation: opDeny verb: POST",
"http",
403,
},
@ -152,7 +152,7 @@ var allowListsForServiceInvocationTests = []struct {
"httptogrpctest",
"allowlists-callee-grpc",
"httptogrpctest",
"HTTP call failed with failed to invoke, id: allowlists-callee-grpc, err: rpc error: code = PermissionDenied desc = access control policy has denied access to appid: allowlists-caller operation: httpToGrpcTest verb: NONE",
"HTTP call failed with failed to invoke, id: allowlists-callee-grpc, err: rpc error: code = PermissionDenied desc = access control policy has denied access to id: spiffe://public/ns/dapr-tests/allowlists-caller operation: httpToGrpcTest verb: NONE",
"grpc",
403,
},

View File

@ -41,11 +41,6 @@ func InitHTTPClient(allowHTTP2 bool) {
httpClient = NewHTTPClient(allowHTTP2)
}
// GetHTTPClient returns the shared httpClient object.
func GetHTTPClient() *http.Client {
return httpClient
}
// NewHTTPClient initializes a new *http.Client.
// This should not be used except in rare circumstances. Developers should use the shared httpClient instead to re-use sockets as much as possible.
func NewHTTPClient(allowHTTP2 bool) *http.Client {
@ -281,3 +276,8 @@ func SanitizeHTTPURL(url string) string {
return url
}
// GetHTTPClient returns the shared httpClient object.
func GetHTTPClient() *http.Client {
return httpClient
}

View File

@ -93,7 +93,7 @@ func (i *insecure) Run(t *testing.T, ctx context.Context) {
errCh := make(chan error, 1)
go func() {
errCh <- secProv.Start(ctx)
errCh <- secProv.Run(ctx)
}()
t.Cleanup(func() { cancel(); require.NoError(t, <-errCh) })
@ -113,7 +113,7 @@ func (i *insecure) Run(t *testing.T, ctx context.Context) {
}
host := "localhost:" + strconv.Itoa(i.places[j].Port())
conn, cerr := grpc.DialContext(ctx, host, grpc.WithBlock(),
grpc.WithReturnConnectionError(), sec.GRPCDialOption(placeID),
grpc.WithReturnConnectionError(), sec.GRPCDialOptionMTLS(placeID),
)
if cerr != nil {
return false

View File

@ -138,7 +138,7 @@ func (j *jwks) Run(t *testing.T, ctx context.Context) {
errCh := make(chan error, 1)
go func() {
errCh <- secProv.Start(ctx)
errCh <- secProv.Run(ctx)
}()
t.Cleanup(func() { cancel(); require.NoError(t, <-errCh) })
@ -158,7 +158,7 @@ func (j *jwks) Run(t *testing.T, ctx context.Context) {
}
host := "localhost:" + strconv.Itoa(j.places[i].Port())
conn, cerr := grpc.DialContext(ctx, host, grpc.WithBlock(),
grpc.WithReturnConnectionError(), sec.GRPCDialOption(placeID),
grpc.WithReturnConnectionError(), sec.GRPCDialOptionMTLS(placeID),
)
if cerr != nil {
return false