diff --git a/src/main/java/org/tikv/common/util/ChannelFactory.java b/src/main/java/org/tikv/common/util/ChannelFactory.java index 4d9bcbd5ec..f91719352f 100644 --- a/src/main/java/org/tikv/common/util/ChannelFactory.java +++ b/src/main/java/org/tikv/common/util/ChannelFactory.java @@ -17,6 +17,7 @@ package org.tikv.common.util; +import com.google.common.annotations.VisibleForTesting; import io.grpc.ManagedChannel; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyChannelBuilder; @@ -28,6 +29,7 @@ import java.net.URI; import java.security.KeyStore; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLException; import javax.net.ssl.TrustManagerFactory; @@ -44,16 +46,141 @@ public class ChannelFactory implements AutoCloseable { private final int keepaliveTimeout; private final int idleTimeout; private final ConcurrentHashMap connPool = new ConcurrentHashMap<>(); - private final SslContextBuilder sslContextBuilder; + private final CertContext certContext; + private final AtomicReference sslContextBuilder = new AtomicReference<>(); private static final String PUB_KEY_INFRA = "PKIX"; + private abstract static class CertContext { + protected abstract boolean isModified(); + + protected abstract SslContextBuilder createSslContextBuilder(); + + public SslContextBuilder reload() { + if (isModified()) { + logger.info("reload ssl context"); + return createSslContextBuilder(); + } + return null; + } + } + + private static class JksContext extends CertContext { + private long keyLastModified; + private long trustLastModified; + + private final String keyPath; + private final String keyPassword; + private final String trustPath; + private final String trustPassword; + + public JksContext(String keyPath, String keyPassword, String trustPath, String trustPassword) { + this.keyLastModified = 0; + this.trustLastModified = 0; + + this.keyPath = keyPath; + this.keyPassword = keyPassword; + this.trustPath = trustPath; + this.trustPassword = trustPassword; + } + + @Override + protected synchronized boolean isModified() { + long a = new File(keyPath).lastModified(); + long b = new File(trustPath).lastModified(); + + boolean changed = this.keyLastModified != a || this.trustLastModified != b; + + if (changed) { + this.keyLastModified = a; + this.trustLastModified = b; + } + + return changed; + } + + @Override + protected SslContextBuilder createSslContextBuilder() { + SslContextBuilder builder = GrpcSslContexts.forClient(); + try { + if (keyPath != null && keyPassword != null) { + KeyStore keyStore = KeyStore.getInstance("JKS"); + keyStore.load(new FileInputStream(keyPath), keyPassword.toCharArray()); + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, keyPassword.toCharArray()); + builder.keyManager(keyManagerFactory); + } + if (trustPath != null && trustPassword != null) { + KeyStore trustStore = KeyStore.getInstance("JKS"); + trustStore.load(new FileInputStream(trustPath), trustPassword.toCharArray()); + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(PUB_KEY_INFRA); + trustManagerFactory.init(trustStore); + builder.trustManager(trustManagerFactory); + } + } catch (Exception e) { + logger.error("JKS SSL context builder failed!", e); + } + return builder; + } + } + + private static class OpenSslContext extends CertContext { + private long trustLastModified; + private long chainLastModified; + private long keyLastModified; + + private final String trustPath; + private final String chainPath; + private final String keyPath; + + public OpenSslContext(String trustPath, String chainPath, String keyPath) { + this.trustLastModified = 0; + this.chainLastModified = 0; + this.keyLastModified = 0; + + this.trustPath = trustPath; + this.chainPath = chainPath; + this.keyPath = keyPath; + } + + @Override + protected synchronized boolean isModified() { + long a = new File(trustPath).lastModified(); + long b = new File(chainPath).lastModified(); + long c = new File(keyPath).lastModified(); + + boolean changed = + this.trustLastModified != a || this.chainLastModified != b || this.keyLastModified != c; + + if (changed) { + this.trustLastModified = a; + this.chainLastModified = b; + this.keyLastModified = c; + } + + return changed; + } + + @Override + protected SslContextBuilder createSslContextBuilder() { + SslContextBuilder builder = GrpcSslContexts.forClient(); + if (trustPath != null) { + builder.trustManager(new File(trustPath)); + } + if (chainPath != null && keyPath != null) { + builder.keyManager(new File(chainPath), new File(keyPath)); + } + return builder; + } + } + public ChannelFactory( int maxFrameSize, int keepaliveTime, int keepaliveTimeout, int idleTimeout) { this.maxFrameSize = maxFrameSize; this.keepaliveTime = keepaliveTime; this.keepaliveTimeout = keepaliveTimeout; this.idleTimeout = idleTimeout; - this.sslContextBuilder = null; + this.certContext = null; } public ChannelFactory( @@ -68,8 +195,8 @@ public class ChannelFactory implements AutoCloseable { this.keepaliveTime = keepaliveTime; this.keepaliveTimeout = keepaliveTimeout; this.idleTimeout = idleTimeout; - this.sslContextBuilder = - getSslContextBuilder(trustCertCollectionFilePath, keyCertChainFilePath, keyFilePath); + this.certContext = + new OpenSslContext(trustCertCollectionFilePath, keyCertChainFilePath, keyFilePath); } public ChannelFactory( @@ -79,54 +206,33 @@ public class ChannelFactory implements AutoCloseable { int idleTimeout, String jksKeyPath, String jksKeyPassword, - String jkstrustPath, + String jksTrustPath, String jksTrustPassword) { this.maxFrameSize = maxFrameSize; this.keepaliveTime = keepaliveTime; this.keepaliveTimeout = keepaliveTimeout; this.idleTimeout = idleTimeout; - this.sslContextBuilder = - getSslContextBuilder(jksKeyPath, jksKeyPassword, jkstrustPath, jksTrustPassword); + this.certContext = new JksContext(jksKeyPath, jksKeyPassword, jksTrustPath, jksTrustPassword); } - private SslContextBuilder getSslContextBuilder( - String jksKeyPath, String jksKeyPassword, String jksTrustPath, String jksTrustPassword) { - SslContextBuilder builder = GrpcSslContexts.forClient(); - try { - if (jksKeyPath != null && jksKeyPassword != null) { - KeyStore keyStore = KeyStore.getInstance("JKS"); - keyStore.load(new FileInputStream(jksKeyPath), jksKeyPassword.toCharArray()); - KeyManagerFactory keyManagerFactory = - KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); - keyManagerFactory.init(keyStore, jksKeyPassword.toCharArray()); - builder.keyManager(keyManagerFactory); + @VisibleForTesting + public boolean reloadSslContext() { + if (certContext != null) { + SslContextBuilder newBuilder = certContext.reload(); + if (newBuilder != null) { + sslContextBuilder.set(newBuilder); + return true; } - if (jksTrustPath != null && jksTrustPassword != null) { - KeyStore trustStore = KeyStore.getInstance("JKS"); - trustStore.load(new FileInputStream(jksTrustPath), jksTrustPassword.toCharArray()); - TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(PUB_KEY_INFRA); - trustManagerFactory.init(trustStore); - builder.trustManager(trustManagerFactory); - } - } catch (Exception e) { - logger.error("JKS SSL context builder failed!", e); } - return builder; - } - - private SslContextBuilder getSslContextBuilder( - String trustCertCollectionFilePath, String keyCertChainFilePath, String keyFilePath) { - SslContextBuilder builder = GrpcSslContexts.forClient(); - if (trustCertCollectionFilePath != null) { - builder.trustManager(new File(trustCertCollectionFilePath)); - } - if (keyCertChainFilePath != null && keyFilePath != null) { - builder.keyManager(new File(keyCertChainFilePath), new File(keyFilePath)); - } - return builder; + return false; } public ManagedChannel getChannel(String addressStr, HostMapping hostMapping) { + if (reloadSslContext()) { + logger.info("invalidate connection pool"); + connPool.clear(); + } + return connPool.computeIfAbsent( addressStr, key -> { @@ -153,12 +259,12 @@ public class ChannelFactory implements AutoCloseable { .keepAliveWithoutCalls(true) .idleTimeout(idleTimeout, TimeUnit.SECONDS); - if (sslContextBuilder == null) { + if (certContext == null) { return builder.usePlaintext().build(); } else { - SslContext sslContext = null; + SslContext sslContext; try { - sslContext = sslContextBuilder.build(); + sslContext = sslContextBuilder.get().build(); } catch (SSLException e) { logger.error("create ssl context failed!", e); return null; diff --git a/src/test/java/org/tikv/common/ChannelFactoryTest.java b/src/test/java/org/tikv/common/ChannelFactoryTest.java new file mode 100644 index 0000000000..6296bf5a30 --- /dev/null +++ b/src/test/java/org/tikv/common/ChannelFactoryTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2022 TiKV Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package org.tikv.common; + +import static org.junit.Assert.assertTrue; + +import java.io.File; +import org.junit.Test; +import org.tikv.common.util.ChannelFactory; + +public class ChannelFactoryTest { + @Test + public void testTlsReload() { + final int v = 1024; + String tlsPath = "src/test/resources/tls/"; + String caPath = tlsPath + "ca.crt"; + String clientCertPath = tlsPath + "client.crt"; + String clientKeyPath = tlsPath + "client.pem"; + ChannelFactory factory = new ChannelFactory(v, v, v, v, caPath, clientCertPath, clientKeyPath); + HostMapping mapping = uri -> uri; + + factory.getChannel("127.0.0.1:2379", mapping); + + assertTrue(new File(clientKeyPath).setLastModified(System.currentTimeMillis())); + + assertTrue(factory.reloadSslContext()); + } +} diff --git a/src/test/resources/tls/ca.crt b/src/test/resources/tls/ca.crt new file mode 100644 index 0000000000..4b882f9065 --- /dev/null +++ b/src/test/resources/tls/ca.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDLzCCAhegAwIBAgIQfjsrMhS4NjlriGvuRi3CsjANBgkqhkiG9w0BAQsFADAh +MRAwDgYDVQQKEwdQaW5nQ0FQMQ0wCwYDVQQLEwRUaVVQMCAXDTIyMDQxODA3NTYx +NFoYDzIwNzIwNDA1MDc1NjE0WjAhMRAwDgYDVQQKEwdQaW5nQ0FQMQ0wCwYDVQQL +EwRUaVVQMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAw5KnSdBoz0cg +CHGq0yNmwZ/7XkZNLehK5jtXtWdwwVPSN5Bc1Q+7vmEO3ObKhRsBonIPEqiOJk9Z +jE6/vSihH1vzz7Qs9BUmdFH4S4FLpRIRcuvNpdZzLanTMe2FNt0c16bBWgyvYiw+ +PdTom7HrWaUZIBGDzUKrH3ORPNm7dJL96vadPeH4WbZGGPL+k0CNCcdoESkBRNjL +eR48GcvWiq1o5o2nY5GE4lSiVgw+CWE+vl6DFuM2/z2acFa3mz+zDz/yL1RM9xfX +PQXBfkbidaAhKb4+8Gn6srh7ZlA5tqd7z4Tb+1JWNg9JULxr+sCIodgu1M5BlDuW +SrPtn+1UnQIDAQABo2EwXzAOBgNVHQ8BAf8EBAMCAoQwHQYDVR0lBBYwFAYIKwYB +BQUHAwIGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFFxX8EGF +b9O0IVshObPv1ytKjYLbMA0GCSqGSIb3DQEBCwUAA4IBAQAJlYocp2k0IUK7V3la +gGFLmY3xKxuNfiEG6B1Uf3DRpo2h/MLq9ndEkBpzcOkFCv34QFwxU2Au3NKr6HEp +9y1LovwomCRHcg2VL5sa/pJl2onTb4fJ8yY2ER5qYg95qMVf1o4bUD6aGfOFGaby +uChJ4b6gw3SyWIYVbGORiuK4lxxmkaqjUlXvE5tOwzH1XLP2NoFX+g+kIRBH12jb +TJD8/ykWLmGuXezRk3JmptVP1K/0TtO+8NtFgUmRb10DZgNZY+6qQf+gsGW1e5Ow +unfFXlW2QqxTSnaZKDXlT3Gjz161yX8pTi48j5Hrs3mKDejP/3b/E2f9Cg34EZ/V +hmF8 +-----END CERTIFICATE----- diff --git a/src/test/resources/tls/client.crt b/src/test/resources/tls/client.crt new file mode 100644 index 0000000000..3b0b925b88 --- /dev/null +++ b/src/test/resources/tls/client.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDXDCCAkSgAwIBAgIRAPK1hi2T9x5tGKa9bEzlioEwDQYJKoZIhvcNAQELBQAw +ITEQMA4GA1UEChMHUGluZ0NBUDENMAsGA1UECxMEVGlVUDAeFw0yMjA0MTgwNzU2 +MTRaFw0zMjA0MTUwNzU2MTRaMF4xEDAOBgNVBAoTB1BpbmdDQVAxKTALBgNVBAsT +BFRpVVAwGgYDVQQLExN0aXVwLWNsdXN0ZXItY2xpZW50MR8wHQYDVQQDExZpb3Nt +YW50aHVzLW1pbmktY2xpZW50MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC +AQEAtF3wrFgmHzWqOO5Rk5z3qaIuMUpSTu4aOiTIaucgJF8/Bai00l2ELIgA3WF1 +/yijRPUiLl6Z9A+PEQ6Yg6n1qCdNtcbHYUbJiq+b7WLCsLWWpLDCpE163t70/QkO +kGweHzUqABiFSYqN+aUmJsfgcP+XpuTLYgfgm2IoxW1TrGG2CuFwe9GQvxwpVP2O +r4O3EZF4ERUmbjTfbLzxIlAJGUFGXqVk5ucIavxvRNJoXyMzufYMOt6ZStuVXDP9 +j4M37QAhyx1A9Pn4GA5mgtkLNTwFWQQUhhFHl6qeferhsNIOAk8tqLBHiIdN6vlC +fixlyVlI32Qo6dnFkhwJZulGKwIDAQABo1IwUDAOBgNVHQ8BAf8EBAMCBaAwHQYD +VR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMB8GA1UdIwQYMBaAFFxX8EGFb9O0 +IVshObPv1ytKjYLbMA0GCSqGSIb3DQEBCwUAA4IBAQCOM2ehBIyXxAzrNwiQEfp7 +19Fum01UalaNtCcuU7l33mA8bPIVMLB82oGTt6RwF/hBerDb08zTRIWxpaMy1DuV +4nD/DlFWW5Q2G066cXlpH/tFzwa3BEf0NVZhkYG8XygfGkUbgUi9w8iGGsOBzpWk +I8gzTPoUPxNcI8yzTTSF5LPvwCrEym0K7N+8ZAHflNu3PnnzDRuXA2z2bcXjjWKm +GGgYwh3TXt5DMJBtEQ0tbB/FLUr9uSS4GONLxzf1pWOXeFWXCjr8KXeWLjeAWfJl +DIXViXSBoJhhlerwliwIq6lbP6diD3PZdj/RJTm1S3rWFoJVbhgIkBKu7NpZp11F +-----END CERTIFICATE----- diff --git a/src/test/resources/tls/client.pem b/src/test/resources/tls/client.pem new file mode 100644 index 0000000000..3de7371d1d --- /dev/null +++ b/src/test/resources/tls/client.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQC0XfCsWCYfNao4 +7lGTnPepoi4xSlJO7ho6JMhq5yAkXz8FqLTSXYQsiADdYXX/KKNE9SIuXpn0D48R +DpiDqfWoJ021xsdhRsmKr5vtYsKwtZaksMKkTXre3vT9CQ6QbB4fNSoAGIVJio35 +pSYmx+Bw/5em5MtiB+CbYijFbVOsYbYK4XB70ZC/HClU/Y6vg7cRkXgRFSZuNN9s +vPEiUAkZQUZepWTm5whq/G9E0mhfIzO59gw63plK25VcM/2PgzftACHLHUD0+fgY +DmaC2Qs1PAVZBBSGEUeXqp596uGw0g4CTy2osEeIh03q+UJ+LGXJWUjfZCjp2cWS +HAlm6UYrAgMBAAECggEAGpdW6jG8vREuXWJVSIv1v16XrNCmPdjAqR3PJmOYy4P3 +SKBMuE7tM5uBdSHvQYT+PSZeubNcwyygDQW32oFuJDJXNJtvvZmwEPA+7sqGpYLA +CNu+dnatzLnWKI/zQ7uM3VD7NjRnQiZZNBry+viw0Df+Za6JhZRfusjH9gNeOKWX +yO+gjyUbV4hQkPsX3DCOuuVzVtlHjzHXBrCbm9XWQVgt6nC9lhsF+F4LxGwikYn2 +sUyZ3ZCshSntoI3mpzNxH73J72PnpRrIbUjBGjR6hB2pRtHtYkVr1JndNVk00MPg +P7Bi2JAKQ0dgQnYB8tUAlyhfnmY9NaM6Ec3evq5n0QKBgQDPeZS6xEdbkIiRmLNG +I3rEweQCu5Ibc5LxureJc+JA0d8wWN/Hr/lvA+NKws0TozVQou0lz50wcQbmIrue +8M+uZjmKLfOld9l6YQLTzEbrKGw9vL2qtBIwxU1cFw6JaKyk0dpNI2nkkV/2ugJc +2CBop9xtoSsoPY3a50D52O7i3wKBgQDejUmz5RFtmbvPcqQ9cBj6fIDVZiJ7d3FG +3YlYTV0kBzMMrgBT2jsEaGwtFH2lxCD24Ax4/OnrRCwLu/JgSVSD++o1+Rs4KB4s +AD9jXd/zNC3oc8IJQ+ft6Xn8UMCgTCe3NCYry7rJQZ2hAx68SxbC46yv0qeNa7BX +sh35VjNyNQKBgHglsDt37trXmD64bxju8ul+Xsw2UyYSh8X0mtS+hweCgf09elnp +Tkk7tyRUKu50VSudLjf3QtAKpDQhaQVh7uLP0AJ1GeN4xDhadYixg2AqyIP4CN4R +6XbUyzfJImHwfAn2fLSvDWOPzELU9QlPH3V7v+q8qoFjJALgaIBHYA+BAoGADv/U +xNQefZWL6+pdGWrxtAgqIrfUgR/GubD6rcHhEimODj+38+7UZXKoP82OvlpeomTt +UkYxedLJaS0Mo+KtWIvk+ChG5l0F049ctlTAYELXCUCsBjXWbtl6iD/lC6i2UImq +PO9pMmFCv3RXYPdqnE39+IepFUX5x59Ql9pwczUCgYBECYnFSH7IiewhL2GDsV8n +kKNIWEBfK04PBLAtqTZGGo2msZ8ysXaaeePs8STtPlUk7rEE/6GYgjhDvbOlXJEx +QoGX8knDh/8+itYlxdWZGriliZl9vdZ4PDaoMvLsYDlhhrEP4YYKjh/nf1Y5WYeG +XhheSjlbxT9gBvagCRSitg== +-----END PRIVATE KEY-----