diff --git a/src/common/security.rs b/src/common/security.rs index 483759c..f8430e7 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -11,6 +11,7 @@ use regex::Regex; use tonic::transport::Certificate; use tonic::transport::Channel; use tonic::transport::ClientTlsConfig; +use tonic::transport::Endpoint; use tonic::transport::Identity; use crate::internal_err; @@ -77,28 +78,42 @@ impl SecurityManager { where Factory: FnOnce(Channel) -> Client, { - let addr = "http://".to_string() + &SCHEME_REG.replace(addr, ""); - info!("connect to rpc server at endpoint: {:?}", addr); - let mut builder = Channel::from_shared(addr)? - .tcp_keepalive(Some(Duration::from_secs(10))) - .keep_alive_timeout(Duration::from_secs(3)); - - if !self.ca.is_empty() { - let tls = ClientTlsConfig::new() - .ca_certificate(Certificate::from_pem(&self.ca)) - .identity(Identity::from_pem( - &self.cert, - load_pem_file("private key", &self.key)?, - )); - builder = builder.tls_config(tls)?; + let channel = if !self.ca.is_empty() { + self.tls_channel(addr).await? + } else { + self.default_channel(addr).await? }; - - let ch = builder.connect().await?; + let ch = channel.connect().await?; Ok(factory(ch)) } + + async fn tls_channel(&self, addr: &str) -> Result { + let addr = "https://".to_string() + &SCHEME_REG.replace(addr, ""); + let builder = self.endpoint(addr.to_string())?; + let tls = ClientTlsConfig::new() + .ca_certificate(Certificate::from_pem(&self.ca)) + .identity(Identity::from_pem( + &self.cert, + load_pem_file("private key", &self.key)?, + )); + let builder = builder.tls_config(tls)?; + Ok(builder) + } + + async fn default_channel(&self, addr: &str) -> Result { + let addr = "http://".to_string() + &SCHEME_REG.replace(addr, ""); + self.endpoint(addr) + } + + fn endpoint(&self, addr: String) -> Result { + let endpoint = Channel::from_shared(addr)? + .tcp_keepalive(Some(Duration::from_secs(10))) + .keep_alive_timeout(Duration::from_secs(3)); + Ok(endpoint) + } } #[cfg(test)]