diff --git a/pubsub/kafka/kafka.go b/pubsub/kafka/kafka.go index 0f39a8af0..e12963b98 100644 --- a/pubsub/kafka/kafka.go +++ b/pubsub/kafka/kafka.go @@ -8,7 +8,9 @@ package kafka import ( "context" "crypto/tls" + "crypto/x509" "encoding/base64" + "encoding/pem" "errors" "fmt" "strconv" @@ -24,7 +26,11 @@ import ( ) const ( - key = "partitionKey" + key = "partitionKey" + skipVerify = "skipVerify" + caCert = "caCert" + clientCert = "clientCert" + clientKey = "clientKey" ) // Kafka allows reading/writing to a Kafka consumer group. @@ -47,14 +53,18 @@ type Kafka struct { } type kafkaMetadata struct { - Brokers []string `json:"brokers"` - ConsumerGroup string `json:"consumerGroup"` - ClientID string `json:"clientID"` - AuthRequired bool `json:"authRequired"` - SaslUsername string `json:"saslUsername"` - SaslPassword string `json:"saslPassword"` - InitialOffset int64 `json:"initialOffset"` - MaxMessageBytes int `json:"maxMessageBytes"` + Brokers []string + ConsumerGroup string + ClientID string + AuthRequired bool + SaslUsername string + SaslPassword string + InitialOffset int64 + MaxMessageBytes int + TLSSkipVerify bool + TLSCaCert string + TLSClientCert string + TLSClientKey string } type consumer struct { @@ -137,6 +147,10 @@ func (k *Kafka) Init(metadata pubsub.Metadata) error { k.saslPassword = meta.SaslPassword updateAuthInfo(config, k.saslUsername, k.saslPassword) } + err = updateTLSConfig(config, meta) + if err != nil { + return err + } k.config = config @@ -350,9 +364,49 @@ func (k *Kafka) getKafkaMetadata(metadata pubsub.Metadata) (*kafkaMetadata, erro meta.MaxMessageBytes = maxBytes } + if val, ok := metadata.Properties[clientCert]; ok && val != "" { + if !isValidPEM(val) { + return nil, errors.New("kafka error: invalid client certificate") + } + meta.TLSClientCert = val + } + if val, ok := metadata.Properties[clientKey]; ok && val != "" { + if !isValidPEM(val) { + return nil, errors.New("kafka error: invalid client key") + } + meta.TLSClientKey = val + } + // clientKey and clientCert need to be all specified or all not specified. + if (meta.TLSClientKey == "") != (meta.TLSClientCert == "") { + return nil, errors.New("kafka error: clientKey or clientCert is missing") + } + if val, ok := metadata.Properties[caCert]; ok && val != "" { + if !isValidPEM(val) { + return nil, errors.New("kafka error: invalid ca certificate") + } + meta.TLSCaCert = val + } + if val, ok := metadata.Properties[skipVerify]; ok && val != "" { + boolVal, err := strconv.ParseBool(val) + if err != nil { + return nil, fmt.Errorf("kafka error: invalid value for '%s' attribute: %w", skipVerify, err) + } + meta.TLSSkipVerify = boolVal + if boolVal { + k.logger.Infof("kafka: you are using 'skipVerify' to skip server config verify which is unsafe!") + } + } + return &meta, nil } +// isValidPEM validates the provided input has PEM formatted block. +func isValidPEM(val string) bool { + block, _ := pem.Decode([]byte(val)) + + return block != nil +} + func getSyncProducer(config sarama.Config, brokers []string, maxMessageBytes int) (sarama.SyncProducer, error) { // Add SyncProducer specific properties to copy of base config config.Producer.RequiredAcks = sarama.WaitForAll @@ -376,13 +430,31 @@ func updateAuthInfo(config *sarama.Config, saslUsername, saslPassword string) { config.Net.SASL.User = saslUsername config.Net.SASL.Password = saslPassword config.Net.SASL.Mechanism = sarama.SASLTypePlaintext +} +func updateTLSConfig(config *sarama.Config, metadata *kafkaMetadata) error { + if !metadata.TLSSkipVerify && metadata.TLSCaCert == "" && metadata.TLSClientCert == "" { + return nil + } config.Net.TLS.Enable = true // nolint: gosec - config.Net.TLS.Config = &tls.Config{ - // InsecureSkipVerify: true, - ClientAuth: 0, + config.Net.TLS.Config = &tls.Config{InsecureSkipVerify: metadata.TLSSkipVerify} + if metadata.TLSClientCert != "" && metadata.TLSClientKey != "" { + cert, err := tls.X509KeyPair([]byte(metadata.TLSClientCert), []byte(metadata.TLSClientKey)) + if err != nil { + return fmt.Errorf("unable to load client certificate and key pair. Err: %w", err) + } + config.Net.TLS.Config.Certificates = []tls.Certificate{cert} } + if metadata.TLSCaCert != "" { + caCertPool := x509.NewCertPool() + if ok := caCertPool.AppendCertsFromPEM([]byte(metadata.TLSCaCert)); !ok { + return errors.New("kafka error: unable to load ca certificate") + } + config.Net.TLS.Config.RootCAs = caCertPool + } + + return nil } func (k *Kafka) Close() error { diff --git a/pubsub/kafka/kafka_test.go b/pubsub/kafka/kafka_test.go index 5cb716d94..16fc278c9 100644 --- a/pubsub/kafka/kafka_test.go +++ b/pubsub/kafka/kafka_test.go @@ -17,13 +17,34 @@ import ( "github.com/dapr/kit/logger" ) +var ( + clientCertPemMock = `-----BEGIN CERTIFICATE----- +Y2xpZW50Q2VydA== +-----END CERTIFICATE-----` + clientKeyMock = `-----BEGIN RSA PRIVATE KEY----- +Y2xpZW50S2V5 +-----END RSA PRIVATE KEY-----` + caCertMock = `-----BEGIN CERTIFICATE----- +Y2FDZXJ0 +-----END CERTIFICATE-----` +) + func getKafkaPubsub() *Kafka { return &Kafka{logger: logger.NewLogger("kafka_test")} } -func TestParseMetadata(t *testing.T) { +func getBaseMetadata() pubsub.Metadata { m := pubsub.Metadata{} m.Properties = map[string]string{"consumerGroup": "a", "clientID": "a", "brokers": "a", "authRequired": "false", "maxMessageBytes": "2048"} + return m +} + +func TestParseMetadata(t *testing.T) { + m := pubsub.Metadata{} + m.Properties = map[string]string{ + "consumerGroup": "a", "clientID": "a", "brokers": "a", "authRequired": "false", "maxMessageBytes": "2048", + skipVerify: "true", clientCert: clientCertPemMock, clientKey: clientKeyMock, caCert: caCertMock, + } k := getKafkaPubsub() meta, err := k.getKafkaMetadata(m) assert.Nil(t, err) @@ -31,6 +52,10 @@ func TestParseMetadata(t *testing.T) { assert.Equal(t, "a", meta.ConsumerGroup) assert.Equal(t, "a", meta.ClientID) assert.Equal(t, 2048, meta.MaxMessageBytes) + assert.Equal(t, true, meta.TLSSkipVerify) + assert.Equal(t, clientCertPemMock, meta.TLSClientCert) + assert.Equal(t, clientKeyMock, meta.TLSClientKey) + assert.Equal(t, caCertMock, meta.TLSCaCert) } func TestMissingBrokers(t *testing.T) { @@ -114,3 +139,68 @@ func TestInitialOffset(t *testing.T) { require.NoError(t, err) assert.Equal(t, sarama.OffsetNewest, meta.InitialOffset) } + +func TestTls(t *testing.T) { + k := getKafkaPubsub() + + t.Run("disable tls", func(t *testing.T) { + m := getBaseMetadata() + meta, err := k.getKafkaMetadata(m) + require.NoError(t, err) + assert.NotNil(t, meta) + c := &sarama.Config{} + err = updateTLSConfig(c, meta) + require.NoError(t, err) + assert.Equal(t, false, c.Net.TLS.Enable) + }) + + t.Run("wrong client cert format", func(t *testing.T) { + m := getBaseMetadata() + m.Properties[clientCert] = "clientCert" + meta, err := k.getKafkaMetadata(m) + assert.Error(t, err) + assert.Nil(t, meta) + + assert.Equal(t, "kafka error: invalid client certificate", err.Error()) + }) + + t.Run("wrong client key format", func(t *testing.T) { + m := getBaseMetadata() + m.Properties[clientKey] = "clientKey" + meta, err := k.getKafkaMetadata(m) + assert.Error(t, err) + assert.Nil(t, meta) + + assert.Equal(t, "kafka error: invalid client key", err.Error()) + }) + + t.Run("miss client key", func(t *testing.T) { + m := getBaseMetadata() + m.Properties[clientCert] = clientCertPemMock + meta, err := k.getKafkaMetadata(m) + assert.Error(t, err) + assert.Nil(t, meta) + + assert.Equal(t, "kafka error: clientKey or clientCert is missing", err.Error()) + }) + + t.Run("miss client cert", func(t *testing.T) { + m := getBaseMetadata() + m.Properties[clientKey] = clientKeyMock + meta, err := k.getKafkaMetadata(m) + assert.Error(t, err) + assert.Nil(t, meta) + + assert.Equal(t, "kafka error: clientKey or clientCert is missing", err.Error()) + }) + + t.Run("wrong ca cert format", func(t *testing.T) { + m := getBaseMetadata() + m.Properties[caCert] = "caCert" + meta, err := k.getKafkaMetadata(m) + assert.Error(t, err) + assert.Nil(t, meta) + + assert.Equal(t, "kafka error: invalid ca certificate", err.Error()) + }) +}