From 73e96ddb12712d476f35fe08ea825db64da91e77 Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Fri, 18 Oct 2024 14:21:26 +0300 Subject: [PATCH] feat(outbound): Route TLS connections by SNI (#3160) This is a draft PR that wires together some of the recently introduced changes to the proxy in order to deliver TLS routing functionality. The changes are: This change adds a TLS routing stack that has the following properties: - it always expects that `SNI` is present - uses tls routing information provided by the discovery API to perform routing based on SNI - does not concern itself with terminating TLS by simply proxies the encrypted stream Signed-off-by: Zahari Dichev --- Cargo.lock | 2 + linkerd/app/outbound/Cargo.toml | 2 + linkerd/app/outbound/src/http/concrete.rs | 7 +- linkerd/app/outbound/src/lib.rs | 5 +- linkerd/app/outbound/src/metrics.rs | 9 +- linkerd/app/outbound/src/protocol.rs | 18 +- linkerd/app/outbound/src/sidecar.rs | 79 +++- linkerd/app/outbound/src/tls.rs | 134 ++++++ linkerd/app/outbound/src/tls/concrete.rs | 387 ++++++++++++++++++ linkerd/app/outbound/src/tls/logical.rs | 112 +++++ linkerd/app/outbound/src/tls/logical/route.rs | 100 +++++ .../app/outbound/src/tls/logical/router.rs | 215 ++++++++++ linkerd/app/outbound/src/tls/logical/tests.rs | 213 ++++++++++ .../outbound/src/tls/logical/tests/basic.rs | 73 ++++ linkerd/io/src/sensor.rs | 9 +- linkerd/tls/src/detect_sni.rs | 107 ----- linkerd/tls/src/lib.rs | 6 +- linkerd/tls/src/server.rs | 4 + linkerd/tls/src/server/required_sni.rs | 118 ++++++ 19 files changed, 1480 insertions(+), 120 deletions(-) create mode 100644 linkerd/app/outbound/src/tls.rs create mode 100644 linkerd/app/outbound/src/tls/concrete.rs create mode 100644 linkerd/app/outbound/src/tls/logical.rs create mode 100644 linkerd/app/outbound/src/tls/logical/route.rs create mode 100644 linkerd/app/outbound/src/tls/logical/router.rs create mode 100644 linkerd/app/outbound/src/tls/logical/tests.rs create mode 100644 linkerd/app/outbound/src/tls/logical/tests/basic.rs delete mode 100644 linkerd/tls/src/detect_sni.rs create mode 100644 linkerd/tls/src/server/required_sni.rs diff --git a/Cargo.lock b/Cargo.lock index 37e322a8d..a23d0443f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1341,6 +1341,7 @@ dependencies = [ "linkerd-proxy-client-policy", "linkerd-retry", "linkerd-stack", + "linkerd-tls-route", "linkerd-tonic-stream", "linkerd-tonic-watch", "linkerd-tracing", @@ -1351,6 +1352,7 @@ dependencies = [ "prometheus-client", "thiserror", "tokio", + "tokio-rustls", "tokio-test", "tonic", "tower", diff --git a/linkerd/app/outbound/Cargo.toml b/linkerd/app/outbound/Cargo.toml index 5a2a1c981..97b2ca733 100644 --- a/linkerd/app/outbound/Cargo.toml +++ b/linkerd/app/outbound/Cargo.toml @@ -44,12 +44,14 @@ linkerd-proxy-client-policy = { path = "../../proxy/client-policy", features = [ "proto", ] } linkerd-retry = { path = "../../retry" } +linkerd-tls-route = { path = "../../tls/route" } linkerd-tonic-stream = { path = "../../tonic-stream" } linkerd-tonic-watch = { path = "../../tonic-watch" } [dev-dependencies] hyper = { version = "0.14", features = ["http1", "http2"] } tokio = { version = "1", features = ["macros", "sync", "time"] } +tokio-rustls = "0.24" tokio-test = "0.4" tower-test = "0.4" diff --git a/linkerd/app/outbound/src/http/concrete.rs b/linkerd/app/outbound/src/http/concrete.rs index 90ec48988..ad7a263be 100644 --- a/linkerd/app/outbound/src/http/concrete.rs +++ b/linkerd/app/outbound/src/http/concrete.rs @@ -33,10 +33,13 @@ pub use self::balance::BalancerMetrics; pub enum Dispatch { Balance(NameAddr, EwmaConfig), Forward(Remote, Metadata), - Fail { message: Arc }, + /// A backend dispatcher that explicitly fails all requests. + Fail { + message: Arc, + }, } -/// A backend dispatcher explicitly fails all requests. +/// A backend dispatcher that explicitly fails all requests. #[derive(Debug, thiserror::Error)] #[error("{0}")] pub struct DispatcherFailed(Arc); diff --git a/linkerd/app/outbound/src/lib.rs b/linkerd/app/outbound/src/lib.rs index ac11a54a6..dd96a3c64 100644 --- a/linkerd/app/outbound/src/lib.rs +++ b/linkerd/app/outbound/src/lib.rs @@ -21,7 +21,7 @@ use linkerd_app_core::{ tap, }, svc::{self, ServiceExt}, - tls, + tls::ConnectMeta as TlsConnectMeta, transport::addrs::*, AddrMatch, Error, ProxyRuntime, }; @@ -46,6 +46,7 @@ mod sidecar; pub mod tcp; #[cfg(any(test, feature = "test-util"))] pub mod test_util; +pub mod tls; mod zone; pub use self::discover::{spawn_synthesized_profile_policy, synthesize_forward_policy, Discovery}; @@ -100,7 +101,7 @@ struct Runtime { drain: drain::Watch, } -pub type ConnectMeta = tls::ConnectMeta>; +pub type ConnectMeta = TlsConnectMeta>; /// A reference to a frontend/apex resource, usually a service. #[derive(Clone, Debug, PartialEq, Eq, Hash)] diff --git a/linkerd/app/outbound/src/metrics.rs b/linkerd/app/outbound/src/metrics.rs index 7f0452e60..c1a8518e0 100644 --- a/linkerd/app/outbound/src/metrics.rs +++ b/linkerd/app/outbound/src/metrics.rs @@ -37,6 +37,7 @@ pub struct OutboundMetrics { pub(crate) struct PromMetrics { pub(crate) http: crate::http::HttpMetrics, pub(crate) opaq: crate::opaq::OpaqMetrics, + pub(crate) tls: crate::tls::TlsMetrics, pub(crate) zone: crate::zone::TcpZoneMetrics, } @@ -92,8 +93,14 @@ impl PromMetrics { let opaq = crate::opaq::OpaqMetrics::register(registry.sub_registry_with_prefix("tcp")); let zone = crate::zone::TcpZoneMetrics::register(registry.sub_registry_with_prefix("tcp")); + let tls = crate::tls::TlsMetrics::register(registry.sub_registry_with_prefix("tls")); - Self { http, opaq, zone } + Self { + http, + opaq, + tls, + zone, + } } } diff --git a/linkerd/app/outbound/src/protocol.rs b/linkerd/app/outbound/src/protocol.rs index b9904cd42..4af2a9271 100644 --- a/linkerd/app/outbound/src/protocol.rs +++ b/linkerd/app/outbound/src/protocol.rs @@ -15,6 +15,7 @@ pub enum Protocol { Http2, Detect, Opaque, + Tls, } // === impl Outbound === @@ -29,6 +30,7 @@ impl Outbound { pub fn push_protocol( self, http: svc::ArcNewCloneHttp>, + tls: svc::ArcNewCloneTcp>>, ) -> Outbound> where // Target type indicating whether detection should be skipped. @@ -83,7 +85,14 @@ impl Outbound { http.map_stack(|_, _, http| { // First separate traffic that needs protocol detection. Then switch // between traffic that is known to be HTTP or opaque. - http.push_switch(Ok::<_, Infallible>, opaq.clone().into_inner()) + let known = http.push_switch( + Ok::<_, Infallible>, + opaq.clone() + .push_switch(Ok::<_, Infallible>, tls.clone()) + .into_inner(), + ); + + known .push_on_service(svc::MapTargetLayer::new(io::EitherIo::Left)) .push_switch( |parent: T| -> Result<_, Infallible> { @@ -96,7 +105,12 @@ impl Outbound { version: http::Version::H2, parent, }))), - Protocol::Opaque => Ok(svc::Either::A(svc::Either::B(parent))), + Protocol::Opaque => { + Ok(svc::Either::A(svc::Either::B(svc::Either::A(parent)))) + } + Protocol::Tls => { + Ok(svc::Either::A(svc::Either::B(svc::Either::B(parent)))) + } Protocol::Detect => Ok(svc::Either::B(parent)), } }, diff --git a/linkerd/app/outbound/src/sidecar.rs b/linkerd/app/outbound/src/sidecar.rs index 25a046474..57278d6b3 100644 --- a/linkerd/app/outbound/src/sidecar.rs +++ b/linkerd/app/outbound/src/sidecar.rs @@ -1,7 +1,7 @@ use crate::{ http, opaq, policy, protocol::{self, Protocol}, - Discovery, Outbound, ParentRef, + tls, Discovery, Outbound, ParentRef, }; use linkerd_app_core::{ io, profiles, @@ -32,6 +32,12 @@ struct HttpSidecar { routes: watch::Receiver, } +#[derive(Clone, Debug)] +struct TlsSidecar { + orig_dst: OrigDstAddr, + routes: watch::Receiver, +} + // === impl Outbound === impl Outbound<()> { @@ -53,6 +59,12 @@ impl Outbound<()> { R::Resolution: Unpin, { let opaq = self.to_tcp_connect().push_opaq_cached(resolve.clone()); + let tls = self + .to_tcp_connect() + .push_tls_cached(resolve.clone()) + .into_stack() + .push_map_target(TlsSidecar::from) + .arc_new_clone_tcp(); let http = self .to_tcp_connect() @@ -64,7 +76,8 @@ impl Outbound<()> { .push_map_target(HttpSidecar::from) .arc_new_clone_http(); - opaq.push_protocol(http.into_inner()) + opaq.clone() + .push_protocol(http.into_inner(), tls.into_inner()) // Use a dedicated target type to bind discovery results to the // outbound sidecar stack configuration. .map_stack(move |_, _, stk| stk.push_map_target(Sidecar::from)) @@ -131,7 +144,8 @@ impl svc::Param for Sidecar { match self.policy.borrow().protocol { policy::Protocol::Http1(_) => Protocol::Http1, policy::Protocol::Http2(_) | policy::Protocol::Grpc(_) => Protocol::Http2, - policy::Protocol::Opaque(_) | policy::Protocol::Tls(_) => Protocol::Opaque, + policy::Protocol::Opaque(_) => Protocol::Opaque, + policy::Protocol::Tls(_) => Protocol::Tls, policy::Protocol::Detect { .. } => Protocol::Detect, } } @@ -326,3 +340,62 @@ impl std::hash::Hash for HttpSidecar { self.version.hash(state); } } + +// === impl TlsSidecar === + +impl From for TlsSidecar { + fn from(parent: Sidecar) -> Self { + let orig_dst = parent.orig_dst; + let mut policy = parent.policy.clone(); + + let init = Self::mk_policy_routes(orig_dst, &policy.borrow_and_update()) + .expect("initial policy must be tls"); + let routes = tls::spawn_routes(policy, init, move |policy: &policy::ClientPolicy| { + Self::mk_policy_routes(orig_dst, policy) + }); + TlsSidecar { orig_dst, routes } + } +} + +impl TlsSidecar { + fn mk_policy_routes( + OrigDstAddr(orig_dst): OrigDstAddr, + policy: &policy::ClientPolicy, + ) -> Option { + let parent_ref = ParentRef(policy.parent.clone()); + let routes = match policy.protocol { + policy::Protocol::Tls(policy::tls::Tls { ref routes }) => routes.clone(), + _ => { + tracing::info!("Ignoring a discovery update that changed a route from TLS"); + return None; + } + }; + + Some(tls::Routes { + addr: orig_dst.into(), + meta: parent_ref, + routes, + backends: policy.backends.clone(), + }) + } +} + +impl svc::Param> for TlsSidecar { + fn param(&self) -> watch::Receiver { + self.routes.clone() + } +} + +impl std::cmp::PartialEq for TlsSidecar { + fn eq(&self, other: &Self) -> bool { + self.orig_dst == other.orig_dst + } +} + +impl std::cmp::Eq for TlsSidecar {} + +impl std::hash::Hash for TlsSidecar { + fn hash(&self, state: &mut H) { + self.orig_dst.hash(state); + } +} diff --git a/linkerd/app/outbound/src/tls.rs b/linkerd/app/outbound/src/tls.rs new file mode 100644 index 000000000..0802dc934 --- /dev/null +++ b/linkerd/app/outbound/src/tls.rs @@ -0,0 +1,134 @@ +use crate::{tcp, Outbound}; +use linkerd_app_core::{ + io, + metrics::prom, + proxy::{ + api_resolve::{ConcreteAddr, Metadata}, + core::Resolve, + }, + svc, + tls::{NewDetectRequiredSni, ServerName}, + transport::addrs::*, + Error, +}; +use std::{fmt::Debug, hash::Hash}; +use tokio::sync::watch; + +mod concrete; +mod logical; + +pub use self::logical::{Concrete, Routes}; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct Tls { + sni: ServerName, + parent: T, +} + +pub fn spawn_routes( + mut route_rx: watch::Receiver, + init: Routes, + mut mk: impl FnMut(&T) -> Option + Send + Sync + 'static, +) -> watch::Receiver +where + T: Send + Sync + 'static, +{ + let (tx, rx) = watch::channel(init); + + tokio::spawn(async move { + loop { + let res = tokio::select! { + biased; + _ = tx.closed() => return, + res = route_rx.changed() => res, + }; + + if res.is_err() { + // Drop the `tx` sender when the profile sender is + // dropped. + return; + } + + if let Some(routes) = (mk)(&*route_rx.borrow_and_update()) { + if tx.send(routes).is_err() { + // Drop the `tx` sender when all of its receivers are dropped. + return; + } + } + } + }); + + rx +} + +#[derive(Clone, Debug, Default)] +pub struct TlsMetrics { + balance: concrete::BalancerMetrics, +} + +// === impl Outbound === + +impl Outbound { + /// Builds a stack that proxies TLS connections. + /// + /// This stack uses caching so that a router/load-balancer may be reused + /// across multiple connections. + pub fn push_tls_cached(self, resolve: R) -> Outbound> + where + // Tls target + T: Clone + Debug + PartialEq + Eq + Hash + Send + Sync + 'static, + T: svc::Param>, + // Server-side connection + I: io::AsyncRead + io::AsyncWrite + io::PeerAddr + io::Peek, + I: Debug + Send + Sync + Unpin + 'static, + // Endpoint discovery + R: Resolve, + R::Resolution: Unpin, + // TCP endpoint stack. + C: svc::MakeConnection, Error = io::Error>, + C: Clone + Send + Sync + Unpin + 'static, + C::Connection: Send + Unpin, + C::Future: Send + Unpin, + { + self.push_tcp_endpoint() + .push_tls_concrete(resolve) + .push_tls_logical() + .map_stack(|config, _rt, stk| { + stk.push_new_idle_cached(config.discovery_idle_timeout) + // Use a dedicated target type to configure parameters for + // the TLS stack. It also helps narrow the cache key. + .push_map_target(|(sni, parent): (ServerName, T)| Tls { sni, parent }) + .push(NewDetectRequiredSni::layer( + config.proxy.detect_protocol_timeout, + )) + .arc_new_clone_tcp() + }) + } +} + +// === impl Tls === + +impl svc::Param for Tls { + fn param(&self) -> ServerName { + self.sni.clone() + } +} + +impl svc::Param> for Tls +where + T: svc::Param>, +{ + fn param(&self) -> watch::Receiver { + self.parent.param() + } +} + +// === impl TlsMetrics === + +impl TlsMetrics { + pub fn register(registry: &mut prom::Registry) -> Self { + let balance = + concrete::BalancerMetrics::register(registry.sub_registry_with_prefix("balancer")); + Self { balance } + } +} diff --git a/linkerd/app/outbound/src/tls/concrete.rs b/linkerd/app/outbound/src/tls/concrete.rs new file mode 100644 index 000000000..8fee79587 --- /dev/null +++ b/linkerd/app/outbound/src/tls/concrete.rs @@ -0,0 +1,387 @@ +use crate::{ + metrics::BalancerMetricsParams, + stack_labels, + zone::{tcp_zone_labels, TcpZoneLabels}, + BackendRef, Outbound, ParentRef, +}; +use linkerd_app_core::{ + config::QueueConfig, + drain, io, + metrics::{ + self, + prom::{self, EncodeLabelSetMut}, + OutboundZoneLocality, + }, + proxy::{ + api_resolve::{ConcreteAddr, Metadata}, + core::Resolve, + http::AuthorityOverride, + tcp::{self, balance}, + }, + svc::{self, layer::Layer}, + tls::{self, ServerName}, + transport::{self, addrs::*}, + transport_header::SessionProtocol, + Error, Infallible, NameAddr, +}; +use std::{fmt::Debug, net::SocketAddr, sync::Arc}; +use tracing::info_span; + +/// Parameter configuring dispatcher behavior. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Dispatch { + Balance(NameAddr, balance::EwmaConfig), + Forward(Remote, Metadata), + /// A backend dispatcher that explicitly fails all requests. + Fail { + message: Arc, + }, +} + +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct DispatcherFailed(Arc); + +/// Wraps errors encountered in this module. +#[derive(Debug, thiserror::Error)] +#[error("concrete service {addr}: {source}")] +pub struct ConcreteError { + addr: NameAddr, + #[source] + source: Error, +} + +/// Inner stack target type. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Endpoint { + addr: Remote, + is_local: bool, + metadata: Metadata, + parent: T, +} + +pub type BalancerMetrics = BalancerMetricsParams; + +/// A target configuring a load balancer stack. +#[derive(Clone, Debug, PartialEq, Eq)] +struct Balance { + concrete: NameAddr, + ewma: balance::EwmaConfig, + queue: QueueConfig, + parent: T, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct ConcreteLabels { + concrete: Arc, +} + +impl prom::EncodeLabelSetMut for ConcreteLabels { + fn encode_label_set(&self, enc: &mut prom::encoding::LabelSetEncoder<'_>) -> std::fmt::Result { + use prom::encoding::EncodeLabel; + + ("concrete", &*self.concrete).encode(enc.encode_label())?; + Ok(()) + } +} + +impl prom::encoding::EncodeLabelSet for ConcreteLabels { + fn encode(&self, mut enc: prom::encoding::LabelSetEncoder<'_>) -> std::fmt::Result { + self.encode_label_set(&mut enc) + } +} + +impl svc::ExtractParam> for BalancerMetricsParams { + fn extract_param(&self, bal: &Balance) -> balance::Metrics { + self.metrics(&ConcreteLabels { + concrete: bal.concrete.to_string().into(), + }) + } +} + +// === impl Outbound === + +impl Outbound { + /// Builds a [`svc::NewService`] stack that builds buffered tls services + /// for `T`-typed concrete targets. Connections may be load balanced across + /// a discovered set of replicas or forwarded to a single endpoint, + /// depending on the value of the `Dispatch` parameter. + /// + /// When a balancer has no available inner services, it goes into + /// 'failfast'. While in failfast, buffered requests are failed and the + /// service becomes unavailable so callers may choose alternate concrete + /// services. + pub fn push_tls_concrete( + self, + resolve: R, + ) -> Outbound< + svc::ArcNewService< + T, + impl svc::Service + Clone, + >, + > + where + // Logical target + T: svc::Param, + T: Clone + Debug + Send + Sync + 'static, + T: svc::Param, + // Server-side socket. + I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static, + // Endpoint resolution. + R: Resolve, + R::Resolution: Unpin, + // Endpoint connector. + C: svc::MakeConnection> + Clone + Send + 'static, + C::Connection: Send + Unpin, + C::Metadata: Send + Unpin, + C::Future: Send, + C: Send + Sync + 'static, + { + let resolve = + svc::MapTargetLayer::new(|t: Balance| -> ConcreteAddr { ConcreteAddr(t.concrete) }) + .layer(resolve.into_service()); + + self.map_stack(|config, rt, inner| { + let queue = config.tcp_connection_queue; + + let connect = inner + .push(svc::stack::WithoutConnectionMetadata::layer()) + .push_new_thunk(); + + let forward = connect + .clone() + .push_on_service(rt.metrics.proxy.stack.layer(stack_labels("tls", "forward"))) + .instrument(|e: &Endpoint| info_span!("forward", addr = %e.addr)); + + let endpoint = connect + .push_on_service( + rt.metrics + .proxy + .stack + .layer(stack_labels("tls", "endpoint")), + ) + .instrument(|e: &Endpoint| info_span!("endpoint", addr = %e.addr)); + + let fail = svc::ArcNewService::new(|message: Arc| { + svc::mk(move |_| futures::future::ready(Err(DispatcherFailed(message.clone())))) + }); + + let inbound_ips = config.inbound_ips.clone(); + let balance = endpoint + .push_map_target( + move |((addr, metadata), target): ((SocketAddr, Metadata), Balance)| { + tracing::trace!(%addr, ?metadata, ?target, "Resolved endpoint"); + let is_local = inbound_ips.contains(&addr.ip()); + Endpoint { + addr: Remote(ServerAddr(addr)), + metadata, + is_local, + parent: target.parent, + } + }, + ) + .lift_new_with_target() + .push(tcp::NewBalance::layer( + resolve, + rt.metrics.prom.tls.balance.clone(), + )) + .push(svc::NewMapErr::layer_from_target::()) + .push_on_service(rt.metrics.proxy.stack.layer(stack_labels("tls", "balance"))) + .instrument(|t: &Balance| info_span!("balance", addr = %t.concrete)); + + balance + .push_switch(Ok::<_, Infallible>, forward.into_inner()) + .push_switch( + move |parent: T| -> Result<_, Infallible> { + Ok(match parent.param() { + Dispatch::Balance(concrete, ewma) => { + svc::Either::A(svc::Either::A(Balance { + concrete, + ewma, + queue, + parent, + })) + } + + Dispatch::Forward(addr, meta) => { + svc::Either::A(svc::Either::B(Endpoint { + addr, + is_local: false, + metadata: meta, + parent, + })) + } + Dispatch::Fail { message } => svc::Either::B(message), + }) + }, + svc::stack(fail).check_new_clone().into_inner(), + ) + .push_on_service(tcp::Forward::layer()) + .push_on_service(drain::Retain::layer(rt.drain.clone())) + .push(svc::ArcNewService::layer()) + }) + } +} + +// === impl ConcreteError === + +impl From<(&Balance, Error)> for ConcreteError { + fn from((target, source): (&Balance, Error)) -> Self { + Self { + addr: target.concrete.clone(), + source, + } + } +} + +// === impl Balance === + +impl std::ops::Deref for Balance { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.parent + } +} + +impl svc::Param for Balance { + fn param(&self) -> balance::EwmaConfig { + self.ewma + } +} + +impl svc::Param for Balance { + fn param(&self) -> svc::queue::Capacity { + svc::queue::Capacity(self.queue.capacity) + } +} + +impl svc::Param for Balance { + fn param(&self) -> svc::queue::Timeout { + svc::queue::Timeout(self.queue.failfast_timeout) + } +} + +impl> svc::Param for Balance { + fn param(&self) -> ParentRef { + self.parent.param() + } +} + +impl> svc::Param for Balance { + fn param(&self) -> BackendRef { + self.parent.param() + } +} + +// === impl Endpoint === + +impl svc::Param> for Endpoint { + fn param(&self) -> Remote { + self.addr + } +} + +impl svc::Param> for Endpoint { + fn param(&self) -> Option { + if self.is_local { + return None; + } + self.metadata + .tagged_transport_port() + .map(crate::tcp::tagged_transport::PortOverride) + } +} + +impl svc::Param> for Endpoint { + fn param(&self) -> Option { + if self.is_local { + return None; + } + self.metadata + .authority_override() + .cloned() + .map(AuthorityOverride) + } +} + +impl svc::Param> for Endpoint { + fn param(&self) -> Option { + None + } +} + +impl svc::Param for Endpoint +where + T: svc::Param, +{ + fn param(&self) -> transport::labels::Key { + transport::labels::Key::OutboundClient(self.param()) + } +} + +impl svc::Param for Endpoint +where + T: svc::Param, +{ + fn param(&self) -> metrics::OutboundEndpointLabels { + metrics::OutboundEndpointLabels { + authority: None, + labels: metrics::prefix_labels("dst", self.metadata.labels().iter()), + zone_locality: self.param(), + server_id: self.param(), + target_addr: self.addr.into(), + } + } +} + +impl svc::Param for Endpoint { + fn param(&self) -> OutboundZoneLocality { + OutboundZoneLocality::new(&self.metadata) + } +} + +impl svc::Param for Endpoint { + fn param(&self) -> TcpZoneLabels { + tcp_zone_labels(self.param()) + } +} + +impl svc::Param for Endpoint +where + T: svc::Param, +{ + fn param(&self) -> metrics::EndpointLabels { + metrics::EndpointLabels::from(svc::Param::::param(self)) + } +} + +impl svc::Param for Endpoint { + fn param(&self) -> tls::ConditionalClientTls { + if self.is_local { + return tls::ConditionalClientTls::None(tls::NoClientTls::Loopback); + } + + // If we're transporting an opaque protocol OR we're communicating with + // a gateway, then set an ALPN value indicating support for a transport + // header. + let use_transport_header = self.metadata.tagged_transport_port().is_some() + || self.metadata.authority_override().is_some(); + self.metadata + .identity() + .cloned() + .map(move |mut client_tls| { + client_tls.alpn = if use_transport_header { + use linkerd_app_core::transport_header::PROTOCOL; + Some(tls::client::AlpnProtocols(vec![PROTOCOL.into()])) + } else { + None + }; + + tls::ConditionalClientTls::Some(client_tls) + }) + .unwrap_or(tls::ConditionalClientTls::None( + tls::NoClientTls::NotProvidedByServiceDiscovery, + )) + } +} diff --git a/linkerd/app/outbound/src/tls/logical.rs b/linkerd/app/outbound/src/tls/logical.rs new file mode 100644 index 000000000..71c7bb935 --- /dev/null +++ b/linkerd/app/outbound/src/tls/logical.rs @@ -0,0 +1,112 @@ +use super::concrete; +use crate::{BackendRef, Outbound, ParentRef}; +use linkerd_app_core::{io, svc, tls::ServerName, Addr, Error}; +use linkerd_proxy_client_policy as client_policy; +use std::{fmt::Debug, hash::Hash, sync::Arc}; +use tokio::sync::watch; + +pub mod route; +pub mod router; + +#[cfg(test)] +mod tests; + +/// Indicates the address used for logical routing. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct LogicalAddr(pub Addr); + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Routes { + pub addr: Addr, + pub meta: ParentRef, + pub routes: Arc<[client_policy::tls::Route]>, + pub backends: Arc<[client_policy::Backend]>, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Concrete { + target: concrete::Dispatch, + parent: T, + parent_ref: ParentRef, + backend_ref: BackendRef, +} + +#[derive(Debug, thiserror::Error)] +#[error("no route")] +pub struct NoRoute; + +#[derive(Debug, thiserror::Error)] +#[error("logical service {addr}: {source}")] +pub struct LogicalError { + addr: Addr, + #[source] + source: Error, +} + +impl Outbound { + /// Builds a `NewService` that produces a router service for each logical + /// target. + /// + /// The router uses discovery information (provided on the target) to + /// support per-connection routing over a set of concrete inner services. + /// Only available inner services are used for routing. When there are no + /// available backends, requests are failed with a [`svc::stack::LoadShedError`]. + pub fn push_tls_logical(self) -> Outbound> + where + // Logical target. + T: svc::Param>, + T: svc::Param, + T: Eq + Hash + Clone + Debug + Send + Sync + 'static, + // Concrete stack. + I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static, + // Concrete stack. + N: svc::NewService, Service = NSvc> + Clone + Send + Sync + 'static, + NSvc: svc::Service + Clone + Send + Sync + 'static, + NSvc::Future: Send, + NSvc::Error: Into, + { + self.map_stack(|_config, _, concrete| { + concrete + .lift_new() + .push_on_service(svc::layer::mk(move |concrete: N| { + svc::stack(concrete.clone()) + .push(router::Router::layer()) + .push(svc::NewMapErr::layer_from_target::()) + .arc_new_clone_tcp() + .into_inner() + })) + // Rebuild the inner router stack every time the watch changes. + .push(svc::NewSpawnWatch::::layer_into::< + router::Router, + >()) + .arc_new_clone_tcp() + }) + } +} + +// === impl LogicalError === + +impl From<(&router::Router, Error)> for LogicalError +where + T: Eq + Hash + Clone + Debug, +{ + fn from((target, source): (&router::Router, Error)) -> Self { + let LogicalAddr(addr) = svc::Param::param(target); + Self { addr, source } + } +} + +impl svc::Param for Concrete { + fn param(&self) -> concrete::Dispatch { + self.target.clone() + } +} + +impl svc::Param for Concrete +where + T: svc::Param, +{ + fn param(&self) -> ServerName { + self.parent.param() + } +} diff --git a/linkerd/app/outbound/src/tls/logical/route.rs b/linkerd/app/outbound/src/tls/logical/route.rs new file mode 100644 index 000000000..7f46c9f28 --- /dev/null +++ b/linkerd/app/outbound/src/tls/logical/route.rs @@ -0,0 +1,100 @@ +use super::super::Concrete; +use crate::{ParentRef, RouteRef}; +use linkerd_app_core::{io, svc, Addr, Error}; +use linkerd_distribute as distribute; +use linkerd_tls_route as tls_route; +use std::{fmt::Debug, hash::Hash}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub(crate) struct Backend { + pub(crate) route_ref: RouteRef, + pub(crate) concrete: Concrete, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub(crate) struct MatchedRoute { + pub(super) r#match: tls_route::RouteMatch, + pub(super) params: Route, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub(crate) struct Route { + pub(super) parent: T, + pub(super) addr: Addr, + pub(super) parent_ref: ParentRef, + pub(super) route_ref: RouteRef, + pub(super) distribution: BackendDistribution, +} + +pub(crate) type BackendDistribution = distribute::Distribution>; +pub(crate) type NewDistribute = distribute::NewDistribute, (), N>; + +/// Wraps errors with route metadata. +#[derive(Debug, thiserror::Error)] +#[error("route {}: {source}", route.0)] +struct RouteError { + route: RouteRef, + #[source] + source: Error, +} + +// === impl Backend === + +impl Clone for Backend { + fn clone(&self) -> Self { + Self { + route_ref: self.route_ref.clone(), + concrete: self.concrete.clone(), + } + } +} + +// === impl MatchedRoute === + +impl MatchedRoute +where + // Parent target. + T: Debug + Eq + Hash, + T: Clone + Send + Sync + 'static, +{ + /// Builds a route stack that applies policy filters to requests and + /// distributes requests over each route's backends. These [`Concrete`] + /// backends are expected to be cached/shared by the inner stack. + pub(crate) fn layer( + ) -> impl svc::Layer> + Clone + where + I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static, + // Inner stack. + N: svc::NewService, Service = NSvc> + Clone + Send + Sync + 'static, + NSvc: svc::Service + Clone + Send + Sync + 'static, + NSvc::Future: Send, + NSvc::Error: Into, + { + svc::layer::mk(move |inner| { + svc::stack(inner) + .push_map_target(|t| t) + .push_map_target(|b: Backend| b.concrete) + .lift_new() + .push(NewDistribute::layer()) + // The router does not take the backend's availability into + // consideration, so we must eagerly fail requests to prevent + // leaking tasks onto the runtime. + .push_on_service(svc::LoadShed::layer()) + .push(svc::NewMapErr::layer_with(|rt: &Self| { + let route = rt.params.route_ref.clone(); + move |source| RouteError { + route: route.clone(), + source, + } + })) + .arc_new_clone_tcp() + .into_inner() + }) + } +} + +impl svc::Param> for MatchedRoute { + fn param(&self) -> BackendDistribution { + self.params.distribution.clone() + } +} diff --git a/linkerd/app/outbound/src/tls/logical/router.rs b/linkerd/app/outbound/src/tls/logical/router.rs new file mode 100644 index 000000000..5d1af9de2 --- /dev/null +++ b/linkerd/app/outbound/src/tls/logical/router.rs @@ -0,0 +1,215 @@ +use super::{ + super::{concrete, Concrete}, + route, LogicalAddr, NoRoute, +}; +use crate::{BackendRef, EndpointRef, RouteRef}; +use linkerd_app_core::{ + io, proxy::http, svc, tls::ServerName, transport::addrs::*, Addr, Error, NameAddr, Result, +}; +use linkerd_distribute as distribute; +use linkerd_proxy_client_policy as policy; +use linkerd_tls_route as tls_route; +use std::{fmt::Debug, hash::Hash, sync::Arc}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Router { + pub(super) parent: T, + pub(super) addr: Addr, + pub(super) routes: Arc<[tls_route::Route>]>, + pub(super) backends: distribute::Backends>, +} + +type NewBackendCache = distribute::NewBackendCache, (), N, S>; + +// === impl Router === + +impl Router +where + // Parent target type. + T: Eq + Hash + Clone + Debug + Send + Sync + 'static, + T: svc::Param, +{ + pub fn layer() -> impl svc::Layer> + Clone + where + I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static, + // Concrete stack. + N: svc::NewService, Service = NSvc> + Clone + Send + Sync + 'static, + NSvc: svc::Service + Clone + Send + Sync + 'static, + NSvc::Future: Send, + NSvc::Error: Into, + { + svc::layer::mk(move |inner| { + svc::stack(inner) + .lift_new() + // Each route builds over concrete backends. All of these + // backends are cached here and shared across routes. + .push(NewBackendCache::layer()) + .push_on_service(route::MatchedRoute::layer()) + .push(svc::NewOneshotRoute::::layer_cached()) + .arc_new_clone_tcp() + .into_inner() + }) + } +} + +impl From<(crate::tls::Routes, T)> for Router +where + T: Eq + Hash + Clone + Debug, +{ + fn from((rts, parent): (crate::tls::Routes, T)) -> Self { + let crate::tls::Routes { + addr, + meta: parent_ref, + routes, + backends, + } = rts; + + let mk_concrete = { + let parent = parent.clone(); + let parent_ref = parent_ref.clone(); + + move |backend_ref: BackendRef, target: concrete::Dispatch| Concrete { + target, + parent: parent.clone(), + backend_ref, + parent_ref: parent_ref.clone(), + } + }; + + let mk_dispatch = move |bke: &policy::Backend| match bke.dispatcher { + policy::BackendDispatcher::BalanceP2c( + policy::Load::PeakEwma(policy::PeakEwma { decay, default_rtt }), + policy::EndpointDiscovery::DestinationGet { ref path }, + ) => mk_concrete( + BackendRef(bke.meta.clone()), + concrete::Dispatch::Balance( + path.parse::() + .expect("destination must be a nameaddr"), + http::balance::EwmaConfig { decay, default_rtt }, + ), + ), + policy::BackendDispatcher::Forward(addr, ref md) => mk_concrete( + EndpointRef::new(md, addr.port().try_into().expect("port must not be 0")).into(), + concrete::Dispatch::Forward(Remote(ServerAddr(addr)), md.clone()), + ), + policy::BackendDispatcher::Fail { ref message } => mk_concrete( + BackendRef(policy::Meta::new_default("fail")), + concrete::Dispatch::Fail { + message: message.clone(), + }, + ), + }; + + let mk_route_backend = + |route_ref: &RouteRef, rb: &policy::RouteBackend| { + let concrete = mk_dispatch(&rb.backend); + route::Backend { + route_ref: route_ref.clone(), + concrete, + } + }; + + let mk_distribution = + |rr: &RouteRef, d: &policy::RouteDistribution| match d { + policy::RouteDistribution::Empty => route::BackendDistribution::Empty, + policy::RouteDistribution::FirstAvailable(backends) => { + route::BackendDistribution::first_available( + backends.iter().map(|b| mk_route_backend(rr, b)), + ) + } + policy::RouteDistribution::RandomAvailable(backends) => { + route::BackendDistribution::random_available( + backends + .iter() + .map(|(rb, weight)| (mk_route_backend(rr, rb), *weight)), + ) + .expect("distribution must be valid") + } + }; + + let mk_policy = + |policy::RoutePolicy:: { + meta, distribution, .. + }| { + let route_ref = RouteRef(meta); + let parent_ref = parent_ref.clone(); + + let distribution = mk_distribution(&route_ref, &distribution); + route::Route { + addr: addr.clone(), + parent: parent.clone(), + parent_ref: parent_ref.clone(), + route_ref, + distribution, + } + }; + + let routes = routes + .iter() + .map(|route| tls_route::Route { + snis: route.snis.clone(), + rules: route + .rules + .iter() + .cloned() + .map(|tls_route::Rule { matches, policy }| tls_route::Rule { + matches, + policy: mk_policy(policy), + }) + .collect(), + }) + .collect(); + + let backends = backends.iter().map(mk_dispatch).collect(); + + Self { + routes, + backends, + addr, + parent, + } + } +} + +impl svc::router::SelectRoute for Router +where + T: Clone + Eq + Hash + Debug, + T: svc::Param, +{ + type Key = route::MatchedRoute; + type Error = NoRoute; + + fn select(&self, _: &I) -> Result { + use linkerd_tls_route::SessionInfo; + + let server_name: ServerName = self.parent.param(); + tracing::trace!("Selecting TLS route for {:?}", server_name); + let si = SessionInfo { sni: server_name }; + let (r#match, params) = policy::tls::find(&self.routes, si).ok_or(NoRoute)?; + tracing::debug!(meta = ?params.route_ref, "Selected route"); + tracing::trace!(?r#match); + + Ok(route::MatchedRoute { + r#match, + params: params.clone(), + }) + } +} + +impl svc::Param for Router +where + T: Eq + Hash + Clone + Debug, +{ + fn param(&self) -> LogicalAddr { + LogicalAddr(self.addr.clone()) + } +} + +impl svc::Param>> for Router +where + T: Eq + Hash + Clone + Debug, +{ + fn param(&self) -> distribute::Backends> { + self.backends.clone() + } +} diff --git a/linkerd/app/outbound/src/tls/logical/tests.rs b/linkerd/app/outbound/src/tls/logical/tests.rs new file mode 100644 index 000000000..f12aa283e --- /dev/null +++ b/linkerd/app/outbound/src/tls/logical/tests.rs @@ -0,0 +1,213 @@ +use super::{Outbound, ParentRef, Routes}; +use crate::test_util::*; +use linkerd_app_core::{ + io, + svc::{self, NewService}, + transport::addrs::*, + Result, +}; +use linkerd_app_test::{AsyncReadExt, AsyncWriteExt}; +use linkerd_proxy_client_policy::{self as client_policy, tls::sni}; +use parking_lot::Mutex; +use std::{ + collections::HashMap, + net::SocketAddr, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; +use tokio::sync::watch; + +mod basic; + +const REQUEST: &[u8] = b"who r u?"; +type Reponse = tokio::task::JoinHandle>; + +#[derive(Clone, Debug)] +struct Target { + num: usize, + routes: watch::Receiver, +} + +#[derive(Clone, Debug)] + +struct MockServer { + io: support::io::Builder, + addr: SocketAddr, +} + +#[derive(Clone, Debug, Default)] +struct ConnectTcp { + srvs: Arc>>, +} + +// === impl MockServer === + +impl MockServer { + fn new( + addr: SocketAddr, + service_name: &str, + client_hello: Vec, + ) -> (Self, io::DuplexStream, Reponse) { + let mut io = support::io(); + + io.write(&client_hello) + .write(REQUEST) + .read(service_name.as_bytes()); + + let server = MockServer { io, addr }; + let (io, response) = spawn_io(client_hello); + + (server, io, response) + } +} + +// === impl Target === + +impl PartialEq for Target { + fn eq(&self, other: &Self) -> bool { + self.num == other.num + } +} + +impl Eq for Target {} + +impl std::hash::Hash for Target { + fn hash(&self, state: &mut H) { + self.num.hash(state); + } +} + +impl svc::Param> for Target { + fn param(&self) -> watch::Receiver { + self.routes.clone() + } +} + +// === impl ConnectTcp === + +impl ConnectTcp { + fn add_server(&mut self, s: MockServer) { + self.srvs.lock().insert(s.addr, s); + } +} + +impl>> svc::Service for ConnectTcp { + type Response = (support::io::Mock, Local); + type Error = io::Error; + type Future = future::Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, t: T) -> Self::Future { + let Remote(ServerAddr(addr)) = t.param(); + let mut mock = self + .srvs + .lock() + .remove(&addr) + .expect("tried to connect to an unexpected address"); + + assert_eq!(addr, mock.addr); + let local = Local(ClientAddr(addr)); + future::ok::<_, support::io::Error>((mock.io.build(), local)) + } +} + +fn spawn_io( + client_hello: Vec, +) -> ( + io::DuplexStream, + tokio::task::JoinHandle>, +) { + let (mut client_io, server_io) = io::duplex(100); + let task = tokio::spawn(async move { + client_io.write_all(&client_hello).await?; + client_io.write_all(REQUEST).await?; + + let mut buf = String::with_capacity(100); + client_io.read_to_string(&mut buf).await?; + Ok(buf) + }); + (server_io, task) +} + +fn default_backend(addr: SocketAddr) -> client_policy::Backend { + use client_policy::{Backend, BackendDispatcher, EndpointMetadata, Meta, Queue}; + Backend { + meta: Meta::new_default("test"), + queue: Queue { + capacity: 100, + failfast_timeout: Duration::from_secs(10), + }, + dispatcher: BackendDispatcher::Forward(addr, EndpointMetadata::default()), + } +} + +fn sni_route(backend: client_policy::Backend, sni: sni::MatchSni) -> client_policy::tls::Route { + use client_policy::{ + tls::{Filter, Policy, Route, Rule}, + Meta, RouteBackend, RouteDistribution, + }; + use linkerd_tls_route::r#match::MatchSession; + use once_cell::sync::Lazy; + static NO_FILTERS: Lazy> = Lazy::new(|| Arc::new([])); + Route { + snis: vec![sni], + rules: vec![Rule { + matches: vec![MatchSession::default()], + policy: Policy { + meta: Meta::new_default("test_route"), + filters: NO_FILTERS.clone(), + params: (), + distribution: RouteDistribution::FirstAvailable(Arc::new([RouteBackend { + filters: NO_FILTERS.clone(), + backend, + }])), + }, + }], + } +} + +// generates a sample ClientHello TLS message for testing +fn generate_client_hello(sni: &str) -> Vec { + use tokio_rustls::rustls::{ + internal::msgs::{ + base::Payload, + enums::Compression, + handshake::{ + ClientExtension, ClientHelloPayload, HandshakeMessagePayload, HandshakePayload, + Random, SessionId, + }, + message::{MessagePayload, PlainMessage}, + }, + server::DnsName, + CipherSuite, ContentType, HandshakeType, ProtocolVersion, + }; + + let sni = DnsName::try_from(sni.to_string()).unwrap(); + + let hs_payload = HandshakeMessagePayload { + typ: HandshakeType::ClientHello, + payload: HandshakePayload::ClientHello(ClientHelloPayload { + client_version: ProtocolVersion::TLSv1_2, + random: Random::from([0; 32]), + session_id: SessionId::empty(), + cipher_suites: vec![CipherSuite::TLS_NULL_WITH_NULL_NULL], + compression_methods: vec![Compression::Null], + extensions: vec![ClientExtension::make_sni(sni.borrow())], + }), + }; + + let mut hs_payload_bytes = Vec::default(); + MessagePayload::handshake(hs_payload).encode(&mut hs_payload_bytes); + + let message = PlainMessage { + typ: ContentType::Handshake, + version: ProtocolVersion::TLSv1_2, + payload: Payload(hs_payload_bytes), + }; + + message.into_unencrypted_opaque().encode() +} diff --git a/linkerd/app/outbound/src/tls/logical/tests/basic.rs b/linkerd/app/outbound/src/tls/logical/tests/basic.rs new file mode 100644 index 000000000..1eb1150e7 --- /dev/null +++ b/linkerd/app/outbound/src/tls/logical/tests/basic.rs @@ -0,0 +1,73 @@ +use super::*; +use crate::tls::Tls; +use linkerd_app_core::{ + svc::ServiceExt, + tls::{NewDetectRequiredSni, ServerName}, + trace, NameAddr, +}; +use linkerd_proxy_client_policy as client_policy; +use std::{net::SocketAddr, str::FromStr, sync::Arc}; +use tokio::sync::watch; + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn routes() { + let _trace = trace::test::trace_init(); + + const AUTHORITY: &str = "logical.test.svc.cluster.local"; + const PORT: u16 = 666; + let addr = SocketAddr::new([192, 0, 2, 41].into(), PORT); + let dest: NameAddr = format!("{AUTHORITY}:{PORT}") + .parse::() + .expect("dest addr is valid"); + let resolve = support::resolver().endpoint_exists(dest.clone(), addr, Default::default()); + let (rt, _shutdown) = runtime(); + + let client_hello = generate_client_hello(AUTHORITY); + let (srv, io, rsp) = MockServer::new(addr, AUTHORITY, client_hello); + + let mut connect = ConnectTcp::default(); + connect.add_server(srv); + + let stack = Outbound::new(default_config(), rt, &mut Default::default()) + .with_stack(connect) + .push_tls_concrete(resolve) + .push_tls_logical() + .map_stack(|config, _rt, stk| { + stk.push_new_idle_cached(config.discovery_idle_timeout) + .push_map_target(|(sni, parent): (ServerName, _)| Tls { sni, parent }) + .push(NewDetectRequiredSni::layer(Duration::from_secs(1))) + .arc_new_clone_tcp() + }) + .into_inner(); + + let correct_backend = default_backend(addr); + let correct_route = sni_route( + correct_backend.clone(), + sni::MatchSni::Exact(AUTHORITY.into()), + ); + + let wrong_addr = SocketAddr::new([0, 0, 0, 0].into(), PORT); + let wrong_backend = default_backend(wrong_addr); + let wrong_route_1 = sni_route( + wrong_backend.clone(), + sni::MatchSni::from_str("foo").unwrap(), + ); + let wrong_route_2 = sni_route( + wrong_backend.clone(), + sni::MatchSni::from_str("*.test.svc.cluster.local").unwrap(), + ); + + let (_route_tx, routes) = watch::channel(Routes { + addr: addr.into(), + backends: Arc::new([correct_backend, wrong_backend]), + routes: Arc::new([correct_route, wrong_route_1, wrong_route_2]), + meta: ParentRef(client_policy::Meta::new_default("parent")), + }); + + let target = Target { num: 1, routes }; + let svc = stack.new_service(target); + + svc.oneshot(io).await.unwrap(); + let msg = rsp.await.unwrap().unwrap(); + assert_eq!(msg, AUTHORITY); +} diff --git a/linkerd/io/src/sensor.rs b/linkerd/io/src/sensor.rs index 7721bb6ac..9175b2ab3 100644 --- a/linkerd/io/src/sensor.rs +++ b/linkerd/io/src/sensor.rs @@ -1,4 +1,4 @@ -use crate::{IoSlice, PeerAddr, Poll}; +use crate::{IoSlice, Peek, PeerAddr, Poll}; use futures::ready; use linkerd_errno::Errno; use pin_project::pin_project; @@ -82,3 +82,10 @@ impl PeerAddr for SensorIo { self.io.peer_addr() } } + +#[async_trait::async_trait] +impl Peek for SensorIo { + async fn peek(&self, buf: &mut [u8]) -> Result { + self.io.peek(buf).await + } +} diff --git a/linkerd/tls/src/detect_sni.rs b/linkerd/tls/src/detect_sni.rs deleted file mode 100644 index 45747d63e..000000000 --- a/linkerd/tls/src/detect_sni.rs +++ /dev/null @@ -1,107 +0,0 @@ -use crate::{ - server::{detect_sni, DetectIo, Timeout}, - ServerName, -}; -use linkerd_error::Error; -use linkerd_io as io; -use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Service, ServiceExt}; -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; -use thiserror::Error; -use tokio::time; -use tracing::debug; - -#[derive(Clone, Debug, Error)] -#[error("SNI detection timed out")] -pub struct SniDetectionTimeoutError; - -#[derive(Clone, Debug, Error)] -#[error("Could not find SNI")] -pub struct NoSniFoundError; - -#[derive(Clone, Debug)] -pub struct NewDetectSni { - params: P, - inner: N, -} - -#[derive(Clone, Debug)] -pub struct DetectSni { - target: T, - inner: N, - timeout: Timeout, - params: P, -} - -impl NewDetectSni { - pub fn new(params: P, inner: N) -> Self { - Self { inner, params } - } - - pub fn layer(params: P) -> impl layer::Layer + Clone - where - P: Clone, - { - layer::mk(move |inner| Self::new(params.clone(), inner)) - } -} - -impl NewService for NewDetectSni -where - P: ExtractParam + Clone, - N: Clone, -{ - type Service = DetectSni; - - fn new_service(&self, target: T) -> Self::Service { - let timeout = self.params.extract_param(&target); - DetectSni { - target, - timeout, - inner: self.inner.clone(), - params: self.params.clone(), - } - } -} - -impl Service for DetectSni -where - T: Clone + Send + Sync + 'static, - P: InsertParam + Clone + Send + Sync + 'static, - P::Target: Send + 'static, - I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static, - N: NewService + Clone + Send + 'static, - S: Service> + Send, - S::Error: Into, - S::Future: Send, -{ - type Response = S::Response; - type Error = Error; - type Future = Pin> + Send + 'static>>; - - #[inline] - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, io: I) -> Self::Future { - let target = self.target.clone(); - let new_accept = self.inner.clone(); - let params = self.params.clone(); - - // Detect the SNI from a ClientHello (or timeout). - let Timeout(timeout) = self.timeout; - let detect = time::timeout(timeout, detect_sni(io)); - Box::pin(async move { - let (sni, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??; - let sni = sni.ok_or(NoSniFoundError)?; - - debug!("detected SNI: {:?}", sni); - let svc = new_accept.new_service(params.insert_param(sni, target)); - svc.oneshot(io).await.map_err(Into::into) - }) - } -} diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index 0a281e2b3..4d6b0f613 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -2,12 +2,14 @@ #![forbid(unsafe_code)] pub mod client; -pub mod detect_sni; pub mod server; pub use self::{ client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId}, - server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls}, + server::{ + ClientId, ConditionalServerTls, NewDetectRequiredSni, NewDetectTls, NoServerTls, + NoSniFoundError, ServerTls, SniDetectionTimeoutError, + }, }; use linkerd_dns_name as dns; diff --git a/linkerd/tls/src/server.rs b/linkerd/tls/src/server.rs index 1c85c92ee..44bb498a3 100644 --- a/linkerd/tls/src/server.rs +++ b/linkerd/tls/src/server.rs @@ -1,4 +1,5 @@ mod client_hello; +mod required_sni; use crate::{NegotiatedProtocol, ServerName}; use bytes::BytesMut; @@ -18,6 +19,8 @@ use thiserror::Error; use tokio::time::{self, Duration}; use tracing::{debug, trace, warn}; +pub use self::required_sni::{NewDetectRequiredSni, NoSniFoundError, SniDetectionTimeoutError}; + /// Describes the authenticated identity of a remote client. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct ClientId(pub id::Id); @@ -65,6 +68,7 @@ pub struct NewDetectTls { _local_identity: std::marker::PhantomData L>, } +/// A param type used to indicate the timeout after which detection should fail. #[derive(Copy, Clone, Debug)] pub struct Timeout(pub Duration); diff --git a/linkerd/tls/src/server/required_sni.rs b/linkerd/tls/src/server/required_sni.rs new file mode 100644 index 000000000..1daf6d39b --- /dev/null +++ b/linkerd/tls/src/server/required_sni.rs @@ -0,0 +1,118 @@ +use crate::{ + server::{detect_sni, DetectIo}, + ServerName, +}; +use linkerd_error::Error; +use linkerd_io as io; +use linkerd_stack::{layer, NewService, Service, ServiceExt}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use thiserror::Error; +use tokio::time; +use tracing::debug; + +#[derive(Clone, Debug, Error)] +#[error("SNI detection timed out")] +pub struct SniDetectionTimeoutError; + +#[derive(Clone, Debug, Error)] +#[error("Could not find SNI")] +pub struct NoSniFoundError; + +/// A NewService that instruments an inner stack with knowledge of the +/// connection's TLS ServerName (i.e. from an SNI header). +/// +/// This differs from the parent module's NewDetectTls in a a few ways: +/// +/// - It requires that all connections have an SNI. +/// - It assumes that these connections may not be terminated locally, so there +/// is no concept of a local server name. +/// - There are no special affordances for mutually authenticated TLS, so we +/// make no attempt to detect the client's identity. +/// - The detection timeout is fixed and cannot vary per target (for +/// convenience, to reduce needless boilerplate). +#[derive(Clone, Debug)] +pub struct NewDetectRequiredSni { + inner: N, + timeout: time::Duration, +} + +#[derive(Clone, Debug)] +pub struct DetectRequiredSni { + target: T, + inner: N, + timeout: time::Duration, +} + +// === impl NewDetectRequiredSni === + +impl NewDetectRequiredSni { + fn new(timeout: time::Duration, inner: N) -> Self { + Self { inner, timeout } + } + + pub fn layer(timeout: time::Duration) -> impl layer::Layer + Clone { + layer::mk(move |inner| Self::new(timeout, inner)) + } +} + +impl NewService for NewDetectRequiredSni +where + N: Clone, +{ + type Service = DetectRequiredSni; + + fn new_service(&self, target: T) -> Self::Service { + DetectRequiredSni::new(self.timeout, target, self.inner.clone()) + } +} + +// === impl DetectRequiredSni === + +impl DetectRequiredSni { + fn new(timeout: time::Duration, target: T, inner: N) -> Self { + Self { + target, + inner, + timeout, + } + } +} + +impl Service for DetectRequiredSni +where + T: Clone + Send + Sync + 'static, + I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static, + N: NewService<(ServerName, T), Service = S> + Clone + Send + 'static, + S: Service> + Send, + S::Error: Into, + S::Future: Send, +{ + type Response = S::Response; + type Error = Error; + type Future = Pin> + Send + 'static>>; + + #[inline] + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, io: I) -> Self::Future { + let target = self.target.clone(); + let new_accept = self.inner.clone(); + + // Detect the SNI from a ClientHello (or timeout). + let detect = time::timeout(self.timeout, detect_sni(io)); + Box::pin(async move { + let (res, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??; + let sni = res.ok_or(NoSniFoundError)?; + debug!(?sni, "Detected TLS"); + + let svc = new_accept.new_service((sni, target)); + svc.oneshot(io).await.map_err(Into::into) + }) + } +}