From 28380512f3adacc1acb54daa381c4c71b1a1fa8f Mon Sep 17 00:00:00 2001 From: Liangliang Gu Date: Tue, 28 Sep 2021 11:13:11 +0800 Subject: [PATCH] support tls (#280) --- README.md | 18 +++++ pom.xml | 6 ++ .../java/org/tikv/common/ConfigUtils.java | 6 ++ .../java/org/tikv/common/TiConfiguration.java | 38 +++++++++++ src/main/java/org/tikv/common/TiSession.java | 13 +++- .../org/tikv/common/util/ChannelFactory.java | 68 ++++++++++++++++--- 6 files changed, 138 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 895d72dd3f..6650c1b416 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,24 @@ The following includes ThreadPool related parameters, which can be passed in thr - whether to enable `Compare And Set`, set true if using `RawKVClient.compareAndSet` or `RawKVClient.putIfAbsent` - default: false +### TLS + +#### tikv.tls_enable +- whether to enable TLS +- default: false + +#### tikv.trust_cert_collection +- Trusted certificates for verifying the remote endpoint's certificate, e.g. /home/tidb/ca.pem. The file should contain an X.509 certificate collection in PEM format. +- default: null + +#### tikv.key_cert_chain +- an X.509 certificate chain file in PEM format, e.g. /home/tidb/client.pem. +- default: null + +#### tikv.key_file +- a PKCS#8 private key file in PEM format. e.g. /home/tidb/client-key.pem. +- default: null + ## Metrics Client Java supports exporting metrics to Prometheus using poll mode and viewing on Grafana. The following steps shows how to enable this function. diff --git a/pom.xml b/pom.xml index 09b8374f63..034c8e6921 100644 --- a/pom.xml +++ b/pom.xml @@ -65,6 +65,7 @@ 1.2.17 1.7.16 1.24.0 + 2.0.25.Final 1.6.6 2.12.3 3.0.1 @@ -133,6 +134,11 @@ grpc-services ${grpc.version} + + io.netty + netty-tcnative-boringssl-static + ${netty.tcnative.version} + io.grpc grpc-testing diff --git a/src/main/java/org/tikv/common/ConfigUtils.java b/src/main/java/org/tikv/common/ConfigUtils.java index de50e7c870..8374d6feb5 100644 --- a/src/main/java/org/tikv/common/ConfigUtils.java +++ b/src/main/java/org/tikv/common/ConfigUtils.java @@ -69,6 +69,11 @@ public class ConfigUtils { public static final String TIKV_RAWKV_DEFAULT_BACKOFF_IN_MS = "tikv.rawkv.default_backoff_in_ms"; + public static final String TIKV_TLS_ENABLE = "tikv.tls_enable"; + public static final String TIKV_TRUST_CERT_COLLECTION = "tikv.trust_cert_collection"; + public static final String TIKV_KEY_CERT_CHAIN = "tikv.key_cert_chain"; + public static final String TIKV_KEY_FILE = "tikv.key_file"; + public static final String DEF_PD_ADDRESSES = "127.0.0.1:2379"; public static final String DEF_TIMEOUT = "200ms"; public static final String DEF_TIKV_GRPC_INGEST_TIMEOUT = "200s"; @@ -125,4 +130,5 @@ public class ConfigUtils { public static final int DEF_TIKV_GRPC_KEEPALIVE_TIME = 10; public static final int DEF_TIKV_GRPC_KEEPALIVE_TIMEOUT = 3; + public static final boolean DEF_TIKV_TLS_ENABLE = false; } diff --git a/src/main/java/org/tikv/common/TiConfiguration.java b/src/main/java/org/tikv/common/TiConfiguration.java index eb3e64d80e..2080131088 100644 --- a/src/main/java/org/tikv/common/TiConfiguration.java +++ b/src/main/java/org/tikv/common/TiConfiguration.java @@ -90,6 +90,7 @@ public class TiConfiguration implements Serializable { setIfMissing(TIKV_RAWKV_DEFAULT_BACKOFF_IN_MS, DEF_TIKV_RAWKV_DEFAULT_BACKOFF_IN_MS); setIfMissing(TIKV_GRPC_KEEPALIVE_TIME, DEF_TIKV_GRPC_KEEPALIVE_TIME); setIfMissing(TIKV_GRPC_KEEPALIVE_TIMEOUT, DEF_TIKV_GRPC_KEEPALIVE_TIMEOUT); + setIfMissing(TIKV_TLS_ENABLE, DEF_TIKV_TLS_ENABLE); } public static void listAll() { @@ -291,6 +292,11 @@ public class TiConfiguration implements Serializable { private int rawKVDefaultBackoffInMS = getInt(TIKV_RAWKV_DEFAULT_BACKOFF_IN_MS); + private boolean tlsEnable = getBoolean(TIKV_TLS_ENABLE); + private String trustCertCollectionFile = getOption(TIKV_TRUST_CERT_COLLECTION).orElse(null); + private String keyCertChainFile = getOption(TIKV_KEY_CERT_CHAIN).orElse(null); + private String keyFile = getOption(TIKV_KEY_FILE).orElse(null); + private boolean isTest = false; private int keepaliveTime = getInt(TIKV_GRPC_KEEPALIVE_TIME); @@ -689,4 +695,36 @@ public class TiConfiguration implements Serializable { public void setKeepaliveTimeout(int timeout) { this.keepaliveTimeout = timeout; } + + public boolean isTlsEnable() { + return tlsEnable; + } + + public void setTlsEnable(boolean tlsEnable) { + this.tlsEnable = tlsEnable; + } + + public String getTrustCertCollectionFile() { + return trustCertCollectionFile; + } + + public void setTrustCertCollectionFile(String trustCertCollectionFile) { + this.trustCertCollectionFile = trustCertCollectionFile; + } + + public String getKeyCertChainFile() { + return keyCertChainFile; + } + + public void setKeyCertChainFile(String keyCertChainFile) { + this.keyCertChainFile = keyCertChainFile; + } + + public String getKeyFile() { + return keyFile; + } + + public void setKeyFile(String keyFile) { + this.keyFile = keyFile; + } } diff --git a/src/main/java/org/tikv/common/TiSession.java b/src/main/java/org/tikv/common/TiSession.java index 3866cb50e2..449d876564 100644 --- a/src/main/java/org/tikv/common/TiSession.java +++ b/src/main/java/org/tikv/common/TiSession.java @@ -79,8 +79,17 @@ public class TiSession implements AutoCloseable { public TiSession(TiConfiguration conf) { this.conf = conf; this.channelFactory = - new ChannelFactory( - conf.getMaxFrameSize(), conf.getKeepaliveTime(), conf.getKeepaliveTimeout()); + conf.isTlsEnable() + ? new ChannelFactory( + conf.getMaxFrameSize(), + conf.getKeepaliveTime(), + conf.getKeepaliveTimeout(), + conf.getTrustCertCollectionFile(), + conf.getKeyCertChainFile(), + conf.getKeyFile()) + : new ChannelFactory( + conf.getMaxFrameSize(), conf.getKeepaliveTime(), conf.getKeepaliveTimeout()); + this.client = PDClient.createRaw(conf, channelFactory); this.enableGrpcForward = conf.getEnableGrpcForward(); this.metricsServer = MetricsServer.getInstance(conf); diff --git a/src/main/java/org/tikv/common/util/ChannelFactory.java b/src/main/java/org/tikv/common/util/ChannelFactory.java index 5433ad5acd..a517473089 100644 --- a/src/main/java/org/tikv/common/util/ChannelFactory.java +++ b/src/main/java/org/tikv/common/util/ChannelFactory.java @@ -16,23 +16,60 @@ package org.tikv.common.util; import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.File; import java.net.URI; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.tikv.common.HostMapping; import org.tikv.common.pd.PDUtils; public class ChannelFactory implements AutoCloseable { + private static final Logger logger = LoggerFactory.getLogger(ChannelFactory.class); + private final int maxFrameSize; private final int keepaliveTime; private final int keepaliveTimeout; private final ConcurrentHashMap connPool = new ConcurrentHashMap<>(); + private final SslContextBuilder sslContextBuilder; public ChannelFactory(int maxFrameSize, int keepaliveTime, int keepaliveTimeout) { this.maxFrameSize = maxFrameSize; this.keepaliveTime = keepaliveTime; this.keepaliveTimeout = keepaliveTimeout; + this.sslContextBuilder = null; + } + + public ChannelFactory( + int maxFrameSize, + int keepaliveTime, + int keepaliveTimeout, + String trustCertCollectionFilePath, + String keyCertChainFilePath, + String keyFilePath) { + this.maxFrameSize = maxFrameSize; + this.keepaliveTime = keepaliveTime; + this.keepaliveTimeout = keepaliveTimeout; + this.sslContextBuilder = + getSslContextBuilder(trustCertCollectionFilePath, keyCertChainFilePath, keyFilePath); + } + + private SslContextBuilder getSslContextBuilder( + String trustCertCollectionFilePath, String keyCertChainFilePath, String keyFilePath) { + SslContextBuilder builder = GrpcSslContexts.forClient(); + if (trustCertCollectionFilePath != null) { + builder.trustManager(new File(trustCertCollectionFilePath)); + } + if (keyCertChainFilePath != null && keyFilePath != null) { + builder.keyManager(new File(keyCertChainFilePath), new File(keyFilePath)); + } + return builder; } public ManagedChannel getChannel(String addressStr, HostMapping hostMapping) { @@ -51,16 +88,29 @@ public class ChannelFactory implements AutoCloseable { } catch (Exception e) { throw new IllegalArgumentException("failed to get mapped address " + address, e); } + // Channel should be lazy without actual connection until first call // So a coarse grain lock is ok here - return ManagedChannelBuilder.forAddress(mappedAddr.getHost(), mappedAddr.getPort()) - .maxInboundMessageSize(maxFrameSize) - .keepAliveTime(keepaliveTime, TimeUnit.SECONDS) - .keepAliveTimeout(keepaliveTimeout, TimeUnit.SECONDS) - .keepAliveWithoutCalls(true) - .usePlaintext(true) - .idleTimeout(60, TimeUnit.SECONDS) - .build(); + NettyChannelBuilder builder = + NettyChannelBuilder.forAddress(mappedAddr.getHost(), mappedAddr.getPort()) + .maxInboundMessageSize(maxFrameSize) + .keepAliveTime(keepaliveTime, TimeUnit.SECONDS) + .keepAliveTimeout(keepaliveTimeout, TimeUnit.SECONDS) + .keepAliveWithoutCalls(true) + .idleTimeout(60, TimeUnit.SECONDS); + + if (sslContextBuilder == null) { + return builder.usePlaintext(true).build(); + } else { + SslContext sslContext = null; + try { + sslContext = sslContextBuilder.build(); + } catch (SSLException e) { + logger.error("create ssl context failed!", e); + return null; + } + return builder.sslContext(sslContext).build(); + } }); }