diff --git a/proxy/src/bind.rs b/proxy/src/bind.rs index cfa7ad8f7..dd7161ba7 100644 --- a/proxy/src/bind.rs +++ b/proxy/src/bind.rs @@ -63,6 +63,10 @@ where protocol: Protocol, } +// `Bind` cannot use `ConditionalConnectionConfig` since it uses a +// `tls::Identity` and a `tls::ClientConfig` obtained from different sources. +pub type ConditionalTlsClientConfig = Conditional; + /// A type of service binding. /// /// Some services, for various reasons, may not be able to be used to serve multiple @@ -75,7 +79,7 @@ where B: tower_h2::Body + Send + 'static, ::Buf: Send, { - Bound(WatchService, RebindTls>), + Bound(WatchService>), BindsPerRequest { // When `poll_ready` is called, the _next_ service to be used may be bound // ahead-of-time. This stack is used only to serve the next request to this @@ -132,7 +136,7 @@ pub struct RebindTls { pub type Service = BoundService; -pub type Stack = WatchService, RebindTls>; +pub type Stack = WatchService>; type StackInner = Reconnect>>; @@ -226,26 +230,24 @@ where /// /// When the TLS client configuration is invalidated, this function will /// be called again to bind a new stack. - fn bind_inner_stack(&self, ep: &Endpoint, protocol: &Protocol)-> StackInner { + fn bind_inner_stack( + &self, + ep: &Endpoint, + protocol: &Protocol, + tls_client_config: &ConditionalTlsClientConfig, + )-> StackInner { debug!("bind_inner_stack endpoint={:?}, protocol={:?}", ep, protocol); let addr = ep.address(); - // Like `tls::current_connection_config()` with optional identity. - let tls = match ep.tls_identity() { - Conditional::Some(identity) => { - // TODO: the watch should be an explicit field of `Bind`, rather - // than passed in the context. - match *self.ctx.tls_client_config_watch().borrow() { - Some(ref config) => - Conditional::Some(tls::ConnectionConfig { - identity: identity.clone(), - config: config.clone() - }), - None => Conditional::None(tls::ReasonForNoTls::NoConfig), + // Like `tls::current_connection_config()`. + let tls = ep.tls_identity().and_then(|identity| { + tls_client_config.as_ref().map(|config| { + tls::ConnectionConfig { + identity: identity.clone(), + config: config.clone(), } - }, - Conditional::None(why_no_identity) => Conditional::None(why_no_identity.into()), - }; + }) + }); let client_ctx = ctx::transport::Client::new( &self.ctx, @@ -307,8 +309,8 @@ where }; // TODO: the watch should be an explicit field of `Bind`, rather // than passed in the context. - let tls_client_cfg = self.ctx.tls_client_config_watch().clone(); - WatchService::new(tls_client_cfg, rebind) + let tls_client_config = self.ctx.tls_client_config_watch().clone(); + WatchService::new(tls_client_config, rebind) } pub fn bind_service(&self, ep: &Endpoint, protocol: &Protocol) -> BoundService { @@ -578,19 +580,17 @@ impl Protocol { // ===== impl RebindTls ===== -impl Rebind> for RebindTls +impl Rebind for RebindTls where B: tower_h2::Body + Send + 'static, ::Buf: Send, { type Service = StackInner; - fn rebind(&mut self, _cfg: &Option) -> Self::Service { + fn rebind(&mut self, tls: &ConditionalTlsClientConfig) -> Self::Service { debug!( "rebinding endpoint stack for {:?}:{:?} on TLS config change", self.endpoint, self.protocol, ); - // We don't actually pass in the config, as `self.bind` also already - // owns a config watch of its own. - self.bind.bind_inner_stack(&self.endpoint, &self.protocol) + self.bind.bind_inner_stack(&self.endpoint, &self.protocol, tls) } } diff --git a/proxy/src/conditional.rs b/proxy/src/conditional.rs index 431b42a46..b1ccc4ec6 100644 --- a/proxy/src/conditional.rs +++ b/proxy/src/conditional.rs @@ -22,7 +22,7 @@ where impl std::fmt::Debug for Conditional where C: Clone + std::fmt::Debug, - R: Clone + std::fmt::Debug + R: Clone + std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { match self { @@ -72,10 +72,33 @@ where C: Clone, R: Copy + Clone, { + pub fn and_then(self, f: F) -> Conditional + where + CR: Clone, + R: Into, + RR: Clone, + F: FnOnce(C) -> Conditional, + { + match self { + Conditional::Some(c) => f(c), + Conditional::None(r) => Conditional::None(r.into()), + } + } + pub fn as_ref<'a>(&'a self) -> Conditional<&'a C, R> { match self { Conditional::Some(c) => Conditional::Some(&c), Conditional::None(r) => Conditional::None(*r), } } + + pub fn map(self, f: F) -> Conditional + where + CR: Clone, + R: Into, + RR: Clone, + F: FnOnce(C) -> CR, + { + self.and_then(|c| Conditional::Some(f(c))) + } } diff --git a/proxy/src/ctx/transport.rs b/proxy/src/ctx/transport.rs index 47c46bb1d..edbeb9233 100644 --- a/proxy/src/ctx/transport.rs +++ b/proxy/src/ctx/transport.rs @@ -42,10 +42,7 @@ impl TlsStatus { pub fn from(c: &Conditional) -> Self where C: Clone + std::fmt::Debug { - match c { - Conditional::Some(_) => Conditional::Some(()), - Conditional::None(r) => Conditional::None(*r), - } + c.as_ref().map(|_| ()) } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index bb834b36e..08288d24e 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -264,13 +264,12 @@ where config.inbound_router_capacity, config.inbound_router_max_idle_age, ); - let tls_settings = match &config.tls_settings { - Conditional::Some(settings) => Conditional::Some(tls::ConnectionConfig { + let tls_settings = config.tls_settings.as_ref().map(|settings| { + tls::ConnectionConfig { identity: settings.service_identity.clone(), config: tls_server_config - }), - Conditional::None(r) => Conditional::None(*r), - }; + } + }); serve( inbound_listener, tls_settings, diff --git a/proxy/src/transport/tls/config.rs b/proxy/src/transport/tls/config.rs index 43a3ac507..0501d9c20 100644 --- a/proxy/src/transport/tls/config.rs +++ b/proxy/src/transport/tls/config.rs @@ -72,8 +72,8 @@ impl std::fmt::Debug for ClientConfig { #[derive(Clone)] pub struct ServerConfig(pub(super) Arc); -pub type ClientConfigWatch = Watch>; -pub type ServerConfigWatch = Watch>; +pub type ClientConfigWatch = Watch>; +pub type ServerConfigWatch = Watch>; /// The configuration in effect for a client (`ClientConfig`) or server /// (`ServerConfig`) TLS connection. @@ -267,15 +267,15 @@ pub fn watch_for_config_changes(settings: Conditional<&CommonSettings, ReasonFor let settings = if let Conditional::Some(settings) = settings { settings.clone() } else { - let (client_watch, _) = Watch::new(None); - let (server_watch, _) = Watch::new(None); + let (client_watch, _) = Watch::new(Conditional::None(ReasonForNoTls::NoConfig)); + let (server_watch, _) = Watch::new(Conditional::None(ReasonForNoTls::NoConfig)); let no_future = future::empty(); return (client_watch, server_watch, Box::new(no_future)); }; let changes = settings.stream_changes(Duration::from_secs(1)); - let (client_watch, client_store) = Watch::new(None); - let (server_watch, server_store) = Watch::new(None); + let (client_watch, client_store) = Watch::new(Conditional::None(ReasonForNoTls::NoConfig)); + let (server_watch, server_store) = Watch::new(Conditional::None(ReasonForNoTls::NoConfig)); // `Store::store` will return an error iff all watchers have been dropped, // so we'll use `fold` to cancel the forwarding future. Eventually, we can @@ -286,10 +286,10 @@ pub fn watch_for_config_changes(settings: Conditional<&CommonSettings, ReasonFor (client_store, server_store), |(mut client_store, mut server_store), ref config| { client_store - .store(Some(ClientConfig::from(config))) + .store(Conditional::Some(ClientConfig::from(config))) .map_err(|_| trace!("all client config watchers dropped"))?; server_store - .store(Some(ServerConfig::from(config))) + .store(Conditional::Some(ServerConfig::from(config))) .map_err(|_| trace!("all server config watchers dropped"))?; Ok((client_store, server_store)) }) @@ -332,7 +332,7 @@ impl ClientConfig { /// `ClientConfigWatch`. We can't use `#[cfg(test)]` here because the /// benchmarks use this. pub fn no_tls() -> ClientConfigWatch { - let (watch, _) = Watch::new(None); + let (watch, _) = Watch::new(Conditional::None(ReasonForNoTls::NoConfig)); watch } } @@ -362,22 +362,18 @@ impl ServerConfig { } } -pub fn current_connection_config(watch: &ConditionalConnectionConfig>>) +pub fn current_connection_config( + watch: &ConditionalConnectionConfig>>) -> ConditionalConnectionConfig where C: Clone { - match watch { - Conditional::Some(c) => { - match *c.config.borrow() { - Some(ref config) => - Conditional::Some(ConnectionConfig { - identity: c.identity.clone(), - config: config.clone() - }), - None => Conditional::None(ReasonForNoTls::NoConfig), + watch.as_ref().and_then(|c| { + c.config.borrow().as_ref().map(|config| { + ConnectionConfig { + identity: c.identity.clone(), + config: config.clone() } - }, - Conditional::None(r) => Conditional::None(*r), - } + }) + }) } fn load_file_contents(path: &PathBuf) -> Result, Error> {