feat(outbound): Route TLS connections by SNI (#3160)

This is a draft PR that wires together some of the recently introduced changes to the proxy in order to deliver TLS routing functionality. The changes are: 

This change adds a TLS routing stack that has the following properties: 
 - it always expects that `SNI` is present
 - uses tls routing information provided by the discovery API to perform routing based on SNI
 - does not concern itself with terminating TLS by simply proxies the encrypted stream

Signed-off-by: Zahari Dichev <zaharidichev@gmail.com>
This commit is contained in:
Zahari Dichev 2024-10-18 14:21:26 +03:00 committed by GitHub
parent b65c43aac5
commit 73e96ddb12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1480 additions and 120 deletions

View File

@ -1341,6 +1341,7 @@ dependencies = [
"linkerd-proxy-client-policy",
"linkerd-retry",
"linkerd-stack",
"linkerd-tls-route",
"linkerd-tonic-stream",
"linkerd-tonic-watch",
"linkerd-tracing",
@ -1351,6 +1352,7 @@ dependencies = [
"prometheus-client",
"thiserror",
"tokio",
"tokio-rustls",
"tokio-test",
"tonic",
"tower",

View File

@ -44,12 +44,14 @@ linkerd-proxy-client-policy = { path = "../../proxy/client-policy", features = [
"proto",
] }
linkerd-retry = { path = "../../retry" }
linkerd-tls-route = { path = "../../tls/route" }
linkerd-tonic-stream = { path = "../../tonic-stream" }
linkerd-tonic-watch = { path = "../../tonic-watch" }
[dev-dependencies]
hyper = { version = "0.14", features = ["http1", "http2"] }
tokio = { version = "1", features = ["macros", "sync", "time"] }
tokio-rustls = "0.24"
tokio-test = "0.4"
tower-test = "0.4"

View File

@ -33,10 +33,13 @@ pub use self::balance::BalancerMetrics;
pub enum Dispatch {
Balance(NameAddr, EwmaConfig),
Forward(Remote<ServerAddr>, Metadata),
Fail { message: Arc<str> },
/// A backend dispatcher that explicitly fails all requests.
Fail {
message: Arc<str>,
},
}
/// A backend dispatcher explicitly fails all requests.
/// A backend dispatcher that explicitly fails all requests.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct DispatcherFailed(Arc<str>);

View File

@ -21,7 +21,7 @@ use linkerd_app_core::{
tap,
},
svc::{self, ServiceExt},
tls,
tls::ConnectMeta as TlsConnectMeta,
transport::addrs::*,
AddrMatch, Error, ProxyRuntime,
};
@ -46,6 +46,7 @@ mod sidecar;
pub mod tcp;
#[cfg(any(test, feature = "test-util"))]
pub mod test_util;
pub mod tls;
mod zone;
pub use self::discover::{spawn_synthesized_profile_policy, synthesize_forward_policy, Discovery};
@ -100,7 +101,7 @@ struct Runtime {
drain: drain::Watch,
}
pub type ConnectMeta = tls::ConnectMeta<Local<ClientAddr>>;
pub type ConnectMeta = TlsConnectMeta<Local<ClientAddr>>;
/// A reference to a frontend/apex resource, usually a service.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]

View File

@ -37,6 +37,7 @@ pub struct OutboundMetrics {
pub(crate) struct PromMetrics {
pub(crate) http: crate::http::HttpMetrics,
pub(crate) opaq: crate::opaq::OpaqMetrics,
pub(crate) tls: crate::tls::TlsMetrics,
pub(crate) zone: crate::zone::TcpZoneMetrics,
}
@ -92,8 +93,14 @@ impl PromMetrics {
let opaq = crate::opaq::OpaqMetrics::register(registry.sub_registry_with_prefix("tcp"));
let zone = crate::zone::TcpZoneMetrics::register(registry.sub_registry_with_prefix("tcp"));
let tls = crate::tls::TlsMetrics::register(registry.sub_registry_with_prefix("tls"));
Self { http, opaq, zone }
Self {
http,
opaq,
tls,
zone,
}
}
}

View File

@ -15,6 +15,7 @@ pub enum Protocol {
Http2,
Detect,
Opaque,
Tls,
}
// === impl Outbound ===
@ -29,6 +30,7 @@ impl<N> Outbound<N> {
pub fn push_protocol<T, I, NSvc>(
self,
http: svc::ArcNewCloneHttp<Http<T>>,
tls: svc::ArcNewCloneTcp<T, io::EitherIo<I, io::PrefixedIo<I>>>,
) -> Outbound<svc::ArcNewTcp<T, I>>
where
// Target type indicating whether detection should be skipped.
@ -83,7 +85,14 @@ impl<N> Outbound<N> {
http.map_stack(|_, _, http| {
// First separate traffic that needs protocol detection. Then switch
// between traffic that is known to be HTTP or opaque.
http.push_switch(Ok::<_, Infallible>, opaq.clone().into_inner())
let known = http.push_switch(
Ok::<_, Infallible>,
opaq.clone()
.push_switch(Ok::<_, Infallible>, tls.clone())
.into_inner(),
);
known
.push_on_service(svc::MapTargetLayer::new(io::EitherIo::Left))
.push_switch(
|parent: T| -> Result<_, Infallible> {
@ -96,7 +105,12 @@ impl<N> Outbound<N> {
version: http::Version::H2,
parent,
}))),
Protocol::Opaque => Ok(svc::Either::A(svc::Either::B(parent))),
Protocol::Opaque => {
Ok(svc::Either::A(svc::Either::B(svc::Either::A(parent))))
}
Protocol::Tls => {
Ok(svc::Either::A(svc::Either::B(svc::Either::B(parent))))
}
Protocol::Detect => Ok(svc::Either::B(parent)),
}
},

View File

@ -1,7 +1,7 @@
use crate::{
http, opaq, policy,
protocol::{self, Protocol},
Discovery, Outbound, ParentRef,
tls, Discovery, Outbound, ParentRef,
};
use linkerd_app_core::{
io, profiles,
@ -32,6 +32,12 @@ struct HttpSidecar {
routes: watch::Receiver<http::Routes>,
}
#[derive(Clone, Debug)]
struct TlsSidecar {
orig_dst: OrigDstAddr,
routes: watch::Receiver<tls::Routes>,
}
// === impl Outbound ===
impl Outbound<()> {
@ -53,6 +59,12 @@ impl Outbound<()> {
R::Resolution: Unpin,
{
let opaq = self.to_tcp_connect().push_opaq_cached(resolve.clone());
let tls = self
.to_tcp_connect()
.push_tls_cached(resolve.clone())
.into_stack()
.push_map_target(TlsSidecar::from)
.arc_new_clone_tcp();
let http = self
.to_tcp_connect()
@ -64,7 +76,8 @@ impl Outbound<()> {
.push_map_target(HttpSidecar::from)
.arc_new_clone_http();
opaq.push_protocol(http.into_inner())
opaq.clone()
.push_protocol(http.into_inner(), tls.into_inner())
// Use a dedicated target type to bind discovery results to the
// outbound sidecar stack configuration.
.map_stack(move |_, _, stk| stk.push_map_target(Sidecar::from))
@ -131,7 +144,8 @@ impl svc::Param<Protocol> for Sidecar {
match self.policy.borrow().protocol {
policy::Protocol::Http1(_) => Protocol::Http1,
policy::Protocol::Http2(_) | policy::Protocol::Grpc(_) => Protocol::Http2,
policy::Protocol::Opaque(_) | policy::Protocol::Tls(_) => Protocol::Opaque,
policy::Protocol::Opaque(_) => Protocol::Opaque,
policy::Protocol::Tls(_) => Protocol::Tls,
policy::Protocol::Detect { .. } => Protocol::Detect,
}
}
@ -326,3 +340,62 @@ impl std::hash::Hash for HttpSidecar {
self.version.hash(state);
}
}
// === impl TlsSidecar ===
impl From<Sidecar> for TlsSidecar {
fn from(parent: Sidecar) -> Self {
let orig_dst = parent.orig_dst;
let mut policy = parent.policy.clone();
let init = Self::mk_policy_routes(orig_dst, &policy.borrow_and_update())
.expect("initial policy must be tls");
let routes = tls::spawn_routes(policy, init, move |policy: &policy::ClientPolicy| {
Self::mk_policy_routes(orig_dst, policy)
});
TlsSidecar { orig_dst, routes }
}
}
impl TlsSidecar {
fn mk_policy_routes(
OrigDstAddr(orig_dst): OrigDstAddr,
policy: &policy::ClientPolicy,
) -> Option<tls::Routes> {
let parent_ref = ParentRef(policy.parent.clone());
let routes = match policy.protocol {
policy::Protocol::Tls(policy::tls::Tls { ref routes }) => routes.clone(),
_ => {
tracing::info!("Ignoring a discovery update that changed a route from TLS");
return None;
}
};
Some(tls::Routes {
addr: orig_dst.into(),
meta: parent_ref,
routes,
backends: policy.backends.clone(),
})
}
}
impl svc::Param<watch::Receiver<tls::Routes>> for TlsSidecar {
fn param(&self) -> watch::Receiver<tls::Routes> {
self.routes.clone()
}
}
impl std::cmp::PartialEq for TlsSidecar {
fn eq(&self, other: &Self) -> bool {
self.orig_dst == other.orig_dst
}
}
impl std::cmp::Eq for TlsSidecar {}
impl std::hash::Hash for TlsSidecar {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.orig_dst.hash(state);
}
}

View File

@ -0,0 +1,134 @@
use crate::{tcp, Outbound};
use linkerd_app_core::{
io,
metrics::prom,
proxy::{
api_resolve::{ConcreteAddr, Metadata},
core::Resolve,
},
svc,
tls::{NewDetectRequiredSni, ServerName},
transport::addrs::*,
Error,
};
use std::{fmt::Debug, hash::Hash};
use tokio::sync::watch;
mod concrete;
mod logical;
pub use self::logical::{Concrete, Routes};
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Tls<T> {
sni: ServerName,
parent: T,
}
pub fn spawn_routes<T>(
mut route_rx: watch::Receiver<T>,
init: Routes,
mut mk: impl FnMut(&T) -> Option<Routes> + Send + Sync + 'static,
) -> watch::Receiver<Routes>
where
T: Send + Sync + 'static,
{
let (tx, rx) = watch::channel(init);
tokio::spawn(async move {
loop {
let res = tokio::select! {
biased;
_ = tx.closed() => return,
res = route_rx.changed() => res,
};
if res.is_err() {
// Drop the `tx` sender when the profile sender is
// dropped.
return;
}
if let Some(routes) = (mk)(&*route_rx.borrow_and_update()) {
if tx.send(routes).is_err() {
// Drop the `tx` sender when all of its receivers are dropped.
return;
}
}
}
});
rx
}
#[derive(Clone, Debug, Default)]
pub struct TlsMetrics {
balance: concrete::BalancerMetrics,
}
// === impl Outbound ===
impl<C> Outbound<C> {
/// Builds a stack that proxies TLS connections.
///
/// This stack uses caching so that a router/load-balancer may be reused
/// across multiple connections.
pub fn push_tls_cached<T, I, R>(self, resolve: R) -> Outbound<svc::ArcNewCloneTcp<T, I>>
where
// Tls target
T: Clone + Debug + PartialEq + Eq + Hash + Send + Sync + 'static,
T: svc::Param<watch::Receiver<Routes>>,
// Server-side connection
I: io::AsyncRead + io::AsyncWrite + io::PeerAddr + io::Peek,
I: Debug + Send + Sync + Unpin + 'static,
// Endpoint discovery
R: Resolve<ConcreteAddr, Endpoint = Metadata, Error = Error>,
R::Resolution: Unpin,
// TCP endpoint stack.
C: svc::MakeConnection<tcp::Connect, Metadata = Local<ClientAddr>, Error = io::Error>,
C: Clone + Send + Sync + Unpin + 'static,
C::Connection: Send + Unpin,
C::Future: Send + Unpin,
{
self.push_tcp_endpoint()
.push_tls_concrete(resolve)
.push_tls_logical()
.map_stack(|config, _rt, stk| {
stk.push_new_idle_cached(config.discovery_idle_timeout)
// Use a dedicated target type to configure parameters for
// the TLS stack. It also helps narrow the cache key.
.push_map_target(|(sni, parent): (ServerName, T)| Tls { sni, parent })
.push(NewDetectRequiredSni::layer(
config.proxy.detect_protocol_timeout,
))
.arc_new_clone_tcp()
})
}
}
// === impl Tls ===
impl<T> svc::Param<ServerName> for Tls<T> {
fn param(&self) -> ServerName {
self.sni.clone()
}
}
impl<T> svc::Param<watch::Receiver<logical::Routes>> for Tls<T>
where
T: svc::Param<watch::Receiver<logical::Routes>>,
{
fn param(&self) -> watch::Receiver<logical::Routes> {
self.parent.param()
}
}
// === impl TlsMetrics ===
impl TlsMetrics {
pub fn register(registry: &mut prom::Registry) -> Self {
let balance =
concrete::BalancerMetrics::register(registry.sub_registry_with_prefix("balancer"));
Self { balance }
}
}

View File

@ -0,0 +1,387 @@
use crate::{
metrics::BalancerMetricsParams,
stack_labels,
zone::{tcp_zone_labels, TcpZoneLabels},
BackendRef, Outbound, ParentRef,
};
use linkerd_app_core::{
config::QueueConfig,
drain, io,
metrics::{
self,
prom::{self, EncodeLabelSetMut},
OutboundZoneLocality,
},
proxy::{
api_resolve::{ConcreteAddr, Metadata},
core::Resolve,
http::AuthorityOverride,
tcp::{self, balance},
},
svc::{self, layer::Layer},
tls::{self, ServerName},
transport::{self, addrs::*},
transport_header::SessionProtocol,
Error, Infallible, NameAddr,
};
use std::{fmt::Debug, net::SocketAddr, sync::Arc};
use tracing::info_span;
/// Parameter configuring dispatcher behavior.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Dispatch {
Balance(NameAddr, balance::EwmaConfig),
Forward(Remote<ServerAddr>, Metadata),
/// A backend dispatcher that explicitly fails all requests.
Fail {
message: Arc<str>,
},
}
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct DispatcherFailed(Arc<str>);
/// Wraps errors encountered in this module.
#[derive(Debug, thiserror::Error)]
#[error("concrete service {addr}: {source}")]
pub struct ConcreteError {
addr: NameAddr,
#[source]
source: Error,
}
/// Inner stack target type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Endpoint<T> {
addr: Remote<ServerAddr>,
is_local: bool,
metadata: Metadata,
parent: T,
}
pub type BalancerMetrics = BalancerMetricsParams<ConcreteLabels>;
/// A target configuring a load balancer stack.
#[derive(Clone, Debug, PartialEq, Eq)]
struct Balance<T> {
concrete: NameAddr,
ewma: balance::EwmaConfig,
queue: QueueConfig,
parent: T,
}
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ConcreteLabels {
concrete: Arc<str>,
}
impl prom::EncodeLabelSetMut for ConcreteLabels {
fn encode_label_set(&self, enc: &mut prom::encoding::LabelSetEncoder<'_>) -> std::fmt::Result {
use prom::encoding::EncodeLabel;
("concrete", &*self.concrete).encode(enc.encode_label())?;
Ok(())
}
}
impl prom::encoding::EncodeLabelSet for ConcreteLabels {
fn encode(&self, mut enc: prom::encoding::LabelSetEncoder<'_>) -> std::fmt::Result {
self.encode_label_set(&mut enc)
}
}
impl<T> svc::ExtractParam<balance::Metrics, Balance<T>> for BalancerMetricsParams<ConcreteLabels> {
fn extract_param(&self, bal: &Balance<T>) -> balance::Metrics {
self.metrics(&ConcreteLabels {
concrete: bal.concrete.to_string().into(),
})
}
}
// === impl Outbound ===
impl<C> Outbound<C> {
/// Builds a [`svc::NewService`] stack that builds buffered tls services
/// for `T`-typed concrete targets. Connections may be load balanced across
/// a discovered set of replicas or forwarded to a single endpoint,
/// depending on the value of the `Dispatch` parameter.
///
/// When a balancer has no available inner services, it goes into
/// 'failfast'. While in failfast, buffered requests are failed and the
/// service becomes unavailable so callers may choose alternate concrete
/// services.
pub fn push_tls_concrete<T, I, R>(
self,
resolve: R,
) -> Outbound<
svc::ArcNewService<
T,
impl svc::Service<I, Response = (), Error = Error, Future = impl Send> + Clone,
>,
>
where
// Logical target
T: svc::Param<Dispatch>,
T: Clone + Debug + Send + Sync + 'static,
T: svc::Param<ServerName>,
// Server-side socket.
I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static,
// Endpoint resolution.
R: Resolve<ConcreteAddr, Endpoint = Metadata, Error = Error>,
R::Resolution: Unpin,
// Endpoint connector.
C: svc::MakeConnection<Endpoint<T>> + Clone + Send + 'static,
C::Connection: Send + Unpin,
C::Metadata: Send + Unpin,
C::Future: Send,
C: Send + Sync + 'static,
{
let resolve =
svc::MapTargetLayer::new(|t: Balance<T>| -> ConcreteAddr { ConcreteAddr(t.concrete) })
.layer(resolve.into_service());
self.map_stack(|config, rt, inner| {
let queue = config.tcp_connection_queue;
let connect = inner
.push(svc::stack::WithoutConnectionMetadata::layer())
.push_new_thunk();
let forward = connect
.clone()
.push_on_service(rt.metrics.proxy.stack.layer(stack_labels("tls", "forward")))
.instrument(|e: &Endpoint<T>| info_span!("forward", addr = %e.addr));
let endpoint = connect
.push_on_service(
rt.metrics
.proxy
.stack
.layer(stack_labels("tls", "endpoint")),
)
.instrument(|e: &Endpoint<T>| info_span!("endpoint", addr = %e.addr));
let fail = svc::ArcNewService::new(|message: Arc<str>| {
svc::mk(move |_| futures::future::ready(Err(DispatcherFailed(message.clone()))))
});
let inbound_ips = config.inbound_ips.clone();
let balance = endpoint
.push_map_target(
move |((addr, metadata), target): ((SocketAddr, Metadata), Balance<T>)| {
tracing::trace!(%addr, ?metadata, ?target, "Resolved endpoint");
let is_local = inbound_ips.contains(&addr.ip());
Endpoint {
addr: Remote(ServerAddr(addr)),
metadata,
is_local,
parent: target.parent,
}
},
)
.lift_new_with_target()
.push(tcp::NewBalance::layer(
resolve,
rt.metrics.prom.tls.balance.clone(),
))
.push(svc::NewMapErr::layer_from_target::<ConcreteError, _>())
.push_on_service(rt.metrics.proxy.stack.layer(stack_labels("tls", "balance")))
.instrument(|t: &Balance<T>| info_span!("balance", addr = %t.concrete));
balance
.push_switch(Ok::<_, Infallible>, forward.into_inner())
.push_switch(
move |parent: T| -> Result<_, Infallible> {
Ok(match parent.param() {
Dispatch::Balance(concrete, ewma) => {
svc::Either::A(svc::Either::A(Balance {
concrete,
ewma,
queue,
parent,
}))
}
Dispatch::Forward(addr, meta) => {
svc::Either::A(svc::Either::B(Endpoint {
addr,
is_local: false,
metadata: meta,
parent,
}))
}
Dispatch::Fail { message } => svc::Either::B(message),
})
},
svc::stack(fail).check_new_clone().into_inner(),
)
.push_on_service(tcp::Forward::layer())
.push_on_service(drain::Retain::layer(rt.drain.clone()))
.push(svc::ArcNewService::layer())
})
}
}
// === impl ConcreteError ===
impl<T> From<(&Balance<T>, Error)> for ConcreteError {
fn from((target, source): (&Balance<T>, Error)) -> Self {
Self {
addr: target.concrete.clone(),
source,
}
}
}
// === impl Balance ===
impl<T> std::ops::Deref for Balance<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.parent
}
}
impl<T> svc::Param<balance::EwmaConfig> for Balance<T> {
fn param(&self) -> balance::EwmaConfig {
self.ewma
}
}
impl<T> svc::Param<svc::queue::Capacity> for Balance<T> {
fn param(&self) -> svc::queue::Capacity {
svc::queue::Capacity(self.queue.capacity)
}
}
impl<T> svc::Param<svc::queue::Timeout> for Balance<T> {
fn param(&self) -> svc::queue::Timeout {
svc::queue::Timeout(self.queue.failfast_timeout)
}
}
impl<T: svc::Param<ParentRef>> svc::Param<ParentRef> for Balance<T> {
fn param(&self) -> ParentRef {
self.parent.param()
}
}
impl<T: svc::Param<BackendRef>> svc::Param<BackendRef> for Balance<T> {
fn param(&self) -> BackendRef {
self.parent.param()
}
}
// === impl Endpoint ===
impl<T> svc::Param<Remote<ServerAddr>> for Endpoint<T> {
fn param(&self) -> Remote<ServerAddr> {
self.addr
}
}
impl<T> svc::Param<Option<crate::tcp::tagged_transport::PortOverride>> for Endpoint<T> {
fn param(&self) -> Option<crate::tcp::tagged_transport::PortOverride> {
if self.is_local {
return None;
}
self.metadata
.tagged_transport_port()
.map(crate::tcp::tagged_transport::PortOverride)
}
}
impl<T> svc::Param<Option<AuthorityOverride>> for Endpoint<T> {
fn param(&self) -> Option<AuthorityOverride> {
if self.is_local {
return None;
}
self.metadata
.authority_override()
.cloned()
.map(AuthorityOverride)
}
}
impl<T> svc::Param<Option<SessionProtocol>> for Endpoint<T> {
fn param(&self) -> Option<SessionProtocol> {
None
}
}
impl<T> svc::Param<transport::labels::Key> for Endpoint<T>
where
T: svc::Param<ServerName>,
{
fn param(&self) -> transport::labels::Key {
transport::labels::Key::OutboundClient(self.param())
}
}
impl<T> svc::Param<metrics::OutboundEndpointLabels> for Endpoint<T>
where
T: svc::Param<ServerName>,
{
fn param(&self) -> metrics::OutboundEndpointLabels {
metrics::OutboundEndpointLabels {
authority: None,
labels: metrics::prefix_labels("dst", self.metadata.labels().iter()),
zone_locality: self.param(),
server_id: self.param(),
target_addr: self.addr.into(),
}
}
}
impl<T> svc::Param<OutboundZoneLocality> for Endpoint<T> {
fn param(&self) -> OutboundZoneLocality {
OutboundZoneLocality::new(&self.metadata)
}
}
impl<T> svc::Param<TcpZoneLabels> for Endpoint<T> {
fn param(&self) -> TcpZoneLabels {
tcp_zone_labels(self.param())
}
}
impl<T> svc::Param<metrics::EndpointLabels> for Endpoint<T>
where
T: svc::Param<ServerName>,
{
fn param(&self) -> metrics::EndpointLabels {
metrics::EndpointLabels::from(svc::Param::<metrics::OutboundEndpointLabels>::param(self))
}
}
impl<T> svc::Param<tls::ConditionalClientTls> for Endpoint<T> {
fn param(&self) -> tls::ConditionalClientTls {
if self.is_local {
return tls::ConditionalClientTls::None(tls::NoClientTls::Loopback);
}
// If we're transporting an opaque protocol OR we're communicating with
// a gateway, then set an ALPN value indicating support for a transport
// header.
let use_transport_header = self.metadata.tagged_transport_port().is_some()
|| self.metadata.authority_override().is_some();
self.metadata
.identity()
.cloned()
.map(move |mut client_tls| {
client_tls.alpn = if use_transport_header {
use linkerd_app_core::transport_header::PROTOCOL;
Some(tls::client::AlpnProtocols(vec![PROTOCOL.into()]))
} else {
None
};
tls::ConditionalClientTls::Some(client_tls)
})
.unwrap_or(tls::ConditionalClientTls::None(
tls::NoClientTls::NotProvidedByServiceDiscovery,
))
}
}

View File

@ -0,0 +1,112 @@
use super::concrete;
use crate::{BackendRef, Outbound, ParentRef};
use linkerd_app_core::{io, svc, tls::ServerName, Addr, Error};
use linkerd_proxy_client_policy as client_policy;
use std::{fmt::Debug, hash::Hash, sync::Arc};
use tokio::sync::watch;
pub mod route;
pub mod router;
#[cfg(test)]
mod tests;
/// Indicates the address used for logical routing.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LogicalAddr(pub Addr);
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Routes {
pub addr: Addr,
pub meta: ParentRef,
pub routes: Arc<[client_policy::tls::Route]>,
pub backends: Arc<[client_policy::Backend]>,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Concrete<T> {
target: concrete::Dispatch,
parent: T,
parent_ref: ParentRef,
backend_ref: BackendRef,
}
#[derive(Debug, thiserror::Error)]
#[error("no route")]
pub struct NoRoute;
#[derive(Debug, thiserror::Error)]
#[error("logical service {addr}: {source}")]
pub struct LogicalError {
addr: Addr,
#[source]
source: Error,
}
impl<N> Outbound<N> {
/// Builds a `NewService` that produces a router service for each logical
/// target.
///
/// The router uses discovery information (provided on the target) to
/// support per-connection routing over a set of concrete inner services.
/// Only available inner services are used for routing. When there are no
/// available backends, requests are failed with a [`svc::stack::LoadShedError`].
pub fn push_tls_logical<T, I, NSvc>(self) -> Outbound<svc::ArcNewCloneTcp<T, I>>
where
// Logical target.
T: svc::Param<watch::Receiver<Routes>>,
T: svc::Param<ServerName>,
T: Eq + Hash + Clone + Debug + Send + Sync + 'static,
// Concrete stack.
I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static,
// Concrete stack.
N: svc::NewService<Concrete<T>, Service = NSvc> + Clone + Send + Sync + 'static,
NSvc: svc::Service<I, Response = ()> + Clone + Send + Sync + 'static,
NSvc::Future: Send,
NSvc::Error: Into<Error>,
{
self.map_stack(|_config, _, concrete| {
concrete
.lift_new()
.push_on_service(svc::layer::mk(move |concrete: N| {
svc::stack(concrete.clone())
.push(router::Router::layer())
.push(svc::NewMapErr::layer_from_target::<LogicalError, _>())
.arc_new_clone_tcp()
.into_inner()
}))
// Rebuild the inner router stack every time the watch changes.
.push(svc::NewSpawnWatch::<Routes, _>::layer_into::<
router::Router<T>,
>())
.arc_new_clone_tcp()
})
}
}
// === impl LogicalError ===
impl<T> From<(&router::Router<T>, Error)> for LogicalError
where
T: Eq + Hash + Clone + Debug,
{
fn from((target, source): (&router::Router<T>, Error)) -> Self {
let LogicalAddr(addr) = svc::Param::param(target);
Self { addr, source }
}
}
impl<T> svc::Param<concrete::Dispatch> for Concrete<T> {
fn param(&self) -> concrete::Dispatch {
self.target.clone()
}
}
impl<T> svc::Param<ServerName> for Concrete<T>
where
T: svc::Param<ServerName>,
{
fn param(&self) -> ServerName {
self.parent.param()
}
}

View File

@ -0,0 +1,100 @@
use super::super::Concrete;
use crate::{ParentRef, RouteRef};
use linkerd_app_core::{io, svc, Addr, Error};
use linkerd_distribute as distribute;
use linkerd_tls_route as tls_route;
use std::{fmt::Debug, hash::Hash};
#[derive(Debug, PartialEq, Eq, Hash)]
pub(crate) struct Backend<T> {
pub(crate) route_ref: RouteRef,
pub(crate) concrete: Concrete<T>,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct MatchedRoute<T> {
pub(super) r#match: tls_route::RouteMatch,
pub(super) params: Route<T>,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct Route<T> {
pub(super) parent: T,
pub(super) addr: Addr,
pub(super) parent_ref: ParentRef,
pub(super) route_ref: RouteRef,
pub(super) distribution: BackendDistribution<T>,
}
pub(crate) type BackendDistribution<T> = distribute::Distribution<Backend<T>>;
pub(crate) type NewDistribute<T, N> = distribute::NewDistribute<Backend<T>, (), N>;
/// Wraps errors with route metadata.
#[derive(Debug, thiserror::Error)]
#[error("route {}: {source}", route.0)]
struct RouteError {
route: RouteRef,
#[source]
source: Error,
}
// === impl Backend ===
impl<T: Clone> Clone for Backend<T> {
fn clone(&self) -> Self {
Self {
route_ref: self.route_ref.clone(),
concrete: self.concrete.clone(),
}
}
}
// === impl MatchedRoute ===
impl<T> MatchedRoute<T>
where
// Parent target.
T: Debug + Eq + Hash,
T: Clone + Send + Sync + 'static,
{
/// Builds a route stack that applies policy filters to requests and
/// distributes requests over each route's backends. These [`Concrete`]
/// backends are expected to be cached/shared by the inner stack.
pub(crate) fn layer<N, I, NSvc>(
) -> impl svc::Layer<N, Service = svc::ArcNewCloneTcp<Self, I>> + Clone
where
I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static,
// Inner stack.
N: svc::NewService<Concrete<T>, Service = NSvc> + Clone + Send + Sync + 'static,
NSvc: svc::Service<I, Response = ()> + Clone + Send + Sync + 'static,
NSvc::Future: Send,
NSvc::Error: Into<Error>,
{
svc::layer::mk(move |inner| {
svc::stack(inner)
.push_map_target(|t| t)
.push_map_target(|b: Backend<T>| b.concrete)
.lift_new()
.push(NewDistribute::layer())
// The router does not take the backend's availability into
// consideration, so we must eagerly fail requests to prevent
// leaking tasks onto the runtime.
.push_on_service(svc::LoadShed::layer())
.push(svc::NewMapErr::layer_with(|rt: &Self| {
let route = rt.params.route_ref.clone();
move |source| RouteError {
route: route.clone(),
source,
}
}))
.arc_new_clone_tcp()
.into_inner()
})
}
}
impl<T: Clone> svc::Param<BackendDistribution<T>> for MatchedRoute<T> {
fn param(&self) -> BackendDistribution<T> {
self.params.distribution.clone()
}
}

View File

@ -0,0 +1,215 @@
use super::{
super::{concrete, Concrete},
route, LogicalAddr, NoRoute,
};
use crate::{BackendRef, EndpointRef, RouteRef};
use linkerd_app_core::{
io, proxy::http, svc, tls::ServerName, transport::addrs::*, Addr, Error, NameAddr, Result,
};
use linkerd_distribute as distribute;
use linkerd_proxy_client_policy as policy;
use linkerd_tls_route as tls_route;
use std::{fmt::Debug, hash::Hash, sync::Arc};
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct Router<T: Clone + Debug + Eq + Hash> {
pub(super) parent: T,
pub(super) addr: Addr,
pub(super) routes: Arc<[tls_route::Route<route::Route<T>>]>,
pub(super) backends: distribute::Backends<Concrete<T>>,
}
type NewBackendCache<T, N, S> = distribute::NewBackendCache<Concrete<T>, (), N, S>;
// === impl Router ===
impl<T> Router<T>
where
// Parent target type.
T: Eq + Hash + Clone + Debug + Send + Sync + 'static,
T: svc::Param<ServerName>,
{
pub fn layer<N, I, NSvc>() -> impl svc::Layer<N, Service = svc::ArcNewCloneTcp<Self, I>> + Clone
where
I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static,
// Concrete stack.
N: svc::NewService<Concrete<T>, Service = NSvc> + Clone + Send + Sync + 'static,
NSvc: svc::Service<I, Response = ()> + Clone + Send + Sync + 'static,
NSvc::Future: Send,
NSvc::Error: Into<Error>,
{
svc::layer::mk(move |inner| {
svc::stack(inner)
.lift_new()
// Each route builds over concrete backends. All of these
// backends are cached here and shared across routes.
.push(NewBackendCache::layer())
.push_on_service(route::MatchedRoute::layer())
.push(svc::NewOneshotRoute::<Self, (), _>::layer_cached())
.arc_new_clone_tcp()
.into_inner()
})
}
}
impl<T> From<(crate::tls::Routes, T)> for Router<T>
where
T: Eq + Hash + Clone + Debug,
{
fn from((rts, parent): (crate::tls::Routes, T)) -> Self {
let crate::tls::Routes {
addr,
meta: parent_ref,
routes,
backends,
} = rts;
let mk_concrete = {
let parent = parent.clone();
let parent_ref = parent_ref.clone();
move |backend_ref: BackendRef, target: concrete::Dispatch| Concrete {
target,
parent: parent.clone(),
backend_ref,
parent_ref: parent_ref.clone(),
}
};
let mk_dispatch = move |bke: &policy::Backend| match bke.dispatcher {
policy::BackendDispatcher::BalanceP2c(
policy::Load::PeakEwma(policy::PeakEwma { decay, default_rtt }),
policy::EndpointDiscovery::DestinationGet { ref path },
) => mk_concrete(
BackendRef(bke.meta.clone()),
concrete::Dispatch::Balance(
path.parse::<NameAddr>()
.expect("destination must be a nameaddr"),
http::balance::EwmaConfig { decay, default_rtt },
),
),
policy::BackendDispatcher::Forward(addr, ref md) => mk_concrete(
EndpointRef::new(md, addr.port().try_into().expect("port must not be 0")).into(),
concrete::Dispatch::Forward(Remote(ServerAddr(addr)), md.clone()),
),
policy::BackendDispatcher::Fail { ref message } => mk_concrete(
BackendRef(policy::Meta::new_default("fail")),
concrete::Dispatch::Fail {
message: message.clone(),
},
),
};
let mk_route_backend =
|route_ref: &RouteRef, rb: &policy::RouteBackend<policy::tls::Filter>| {
let concrete = mk_dispatch(&rb.backend);
route::Backend {
route_ref: route_ref.clone(),
concrete,
}
};
let mk_distribution =
|rr: &RouteRef, d: &policy::RouteDistribution<policy::tls::Filter>| match d {
policy::RouteDistribution::Empty => route::BackendDistribution::Empty,
policy::RouteDistribution::FirstAvailable(backends) => {
route::BackendDistribution::first_available(
backends.iter().map(|b| mk_route_backend(rr, b)),
)
}
policy::RouteDistribution::RandomAvailable(backends) => {
route::BackendDistribution::random_available(
backends
.iter()
.map(|(rb, weight)| (mk_route_backend(rr, rb), *weight)),
)
.expect("distribution must be valid")
}
};
let mk_policy =
|policy::RoutePolicy::<policy::tls::Filter, ()> {
meta, distribution, ..
}| {
let route_ref = RouteRef(meta);
let parent_ref = parent_ref.clone();
let distribution = mk_distribution(&route_ref, &distribution);
route::Route {
addr: addr.clone(),
parent: parent.clone(),
parent_ref: parent_ref.clone(),
route_ref,
distribution,
}
};
let routes = routes
.iter()
.map(|route| tls_route::Route {
snis: route.snis.clone(),
rules: route
.rules
.iter()
.cloned()
.map(|tls_route::Rule { matches, policy }| tls_route::Rule {
matches,
policy: mk_policy(policy),
})
.collect(),
})
.collect();
let backends = backends.iter().map(mk_dispatch).collect();
Self {
routes,
backends,
addr,
parent,
}
}
}
impl<T, I> svc::router::SelectRoute<I> for Router<T>
where
T: Clone + Eq + Hash + Debug,
T: svc::Param<ServerName>,
{
type Key = route::MatchedRoute<T>;
type Error = NoRoute;
fn select(&self, _: &I) -> Result<Self::Key, Self::Error> {
use linkerd_tls_route::SessionInfo;
let server_name: ServerName = self.parent.param();
tracing::trace!("Selecting TLS route for {:?}", server_name);
let si = SessionInfo { sni: server_name };
let (r#match, params) = policy::tls::find(&self.routes, si).ok_or(NoRoute)?;
tracing::debug!(meta = ?params.route_ref, "Selected route");
tracing::trace!(?r#match);
Ok(route::MatchedRoute {
r#match,
params: params.clone(),
})
}
}
impl<T> svc::Param<LogicalAddr> for Router<T>
where
T: Eq + Hash + Clone + Debug,
{
fn param(&self) -> LogicalAddr {
LogicalAddr(self.addr.clone())
}
}
impl<T> svc::Param<distribute::Backends<Concrete<T>>> for Router<T>
where
T: Eq + Hash + Clone + Debug,
{
fn param(&self) -> distribute::Backends<Concrete<T>> {
self.backends.clone()
}
}

View File

@ -0,0 +1,213 @@
use super::{Outbound, ParentRef, Routes};
use crate::test_util::*;
use linkerd_app_core::{
io,
svc::{self, NewService},
transport::addrs::*,
Result,
};
use linkerd_app_test::{AsyncReadExt, AsyncWriteExt};
use linkerd_proxy_client_policy::{self as client_policy, tls::sni};
use parking_lot::Mutex;
use std::{
collections::HashMap,
net::SocketAddr,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tokio::sync::watch;
mod basic;
const REQUEST: &[u8] = b"who r u?";
type Reponse = tokio::task::JoinHandle<io::Result<String>>;
#[derive(Clone, Debug)]
struct Target {
num: usize,
routes: watch::Receiver<Routes>,
}
#[derive(Clone, Debug)]
struct MockServer {
io: support::io::Builder,
addr: SocketAddr,
}
#[derive(Clone, Debug, Default)]
struct ConnectTcp {
srvs: Arc<Mutex<HashMap<SocketAddr, MockServer>>>,
}
// === impl MockServer ===
impl MockServer {
fn new(
addr: SocketAddr,
service_name: &str,
client_hello: Vec<u8>,
) -> (Self, io::DuplexStream, Reponse) {
let mut io = support::io();
io.write(&client_hello)
.write(REQUEST)
.read(service_name.as_bytes());
let server = MockServer { io, addr };
let (io, response) = spawn_io(client_hello);
(server, io, response)
}
}
// === impl Target ===
impl PartialEq for Target {
fn eq(&self, other: &Self) -> bool {
self.num == other.num
}
}
impl Eq for Target {}
impl std::hash::Hash for Target {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.num.hash(state);
}
}
impl svc::Param<watch::Receiver<Routes>> for Target {
fn param(&self) -> watch::Receiver<Routes> {
self.routes.clone()
}
}
// === impl ConnectTcp ===
impl ConnectTcp {
fn add_server(&mut self, s: MockServer) {
self.srvs.lock().insert(s.addr, s);
}
}
impl<T: svc::Param<Remote<ServerAddr>>> svc::Service<T> for ConnectTcp {
type Response = (support::io::Mock, Local<ClientAddr>);
type Error = io::Error;
type Future = future::Ready<io::Result<Self::Response>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, t: T) -> Self::Future {
let Remote(ServerAddr(addr)) = t.param();
let mut mock = self
.srvs
.lock()
.remove(&addr)
.expect("tried to connect to an unexpected address");
assert_eq!(addr, mock.addr);
let local = Local(ClientAddr(addr));
future::ok::<_, support::io::Error>((mock.io.build(), local))
}
}
fn spawn_io(
client_hello: Vec<u8>,
) -> (
io::DuplexStream,
tokio::task::JoinHandle<io::Result<String>>,
) {
let (mut client_io, server_io) = io::duplex(100);
let task = tokio::spawn(async move {
client_io.write_all(&client_hello).await?;
client_io.write_all(REQUEST).await?;
let mut buf = String::with_capacity(100);
client_io.read_to_string(&mut buf).await?;
Ok(buf)
});
(server_io, task)
}
fn default_backend(addr: SocketAddr) -> client_policy::Backend {
use client_policy::{Backend, BackendDispatcher, EndpointMetadata, Meta, Queue};
Backend {
meta: Meta::new_default("test"),
queue: Queue {
capacity: 100,
failfast_timeout: Duration::from_secs(10),
},
dispatcher: BackendDispatcher::Forward(addr, EndpointMetadata::default()),
}
}
fn sni_route(backend: client_policy::Backend, sni: sni::MatchSni) -> client_policy::tls::Route {
use client_policy::{
tls::{Filter, Policy, Route, Rule},
Meta, RouteBackend, RouteDistribution,
};
use linkerd_tls_route::r#match::MatchSession;
use once_cell::sync::Lazy;
static NO_FILTERS: Lazy<Arc<[Filter]>> = Lazy::new(|| Arc::new([]));
Route {
snis: vec![sni],
rules: vec![Rule {
matches: vec![MatchSession::default()],
policy: Policy {
meta: Meta::new_default("test_route"),
filters: NO_FILTERS.clone(),
params: (),
distribution: RouteDistribution::FirstAvailable(Arc::new([RouteBackend {
filters: NO_FILTERS.clone(),
backend,
}])),
},
}],
}
}
// generates a sample ClientHello TLS message for testing
fn generate_client_hello(sni: &str) -> Vec<u8> {
use tokio_rustls::rustls::{
internal::msgs::{
base::Payload,
enums::Compression,
handshake::{
ClientExtension, ClientHelloPayload, HandshakeMessagePayload, HandshakePayload,
Random, SessionId,
},
message::{MessagePayload, PlainMessage},
},
server::DnsName,
CipherSuite, ContentType, HandshakeType, ProtocolVersion,
};
let sni = DnsName::try_from(sni.to_string()).unwrap();
let hs_payload = HandshakeMessagePayload {
typ: HandshakeType::ClientHello,
payload: HandshakePayload::ClientHello(ClientHelloPayload {
client_version: ProtocolVersion::TLSv1_2,
random: Random::from([0; 32]),
session_id: SessionId::empty(),
cipher_suites: vec![CipherSuite::TLS_NULL_WITH_NULL_NULL],
compression_methods: vec![Compression::Null],
extensions: vec![ClientExtension::make_sni(sni.borrow())],
}),
};
let mut hs_payload_bytes = Vec::default();
MessagePayload::handshake(hs_payload).encode(&mut hs_payload_bytes);
let message = PlainMessage {
typ: ContentType::Handshake,
version: ProtocolVersion::TLSv1_2,
payload: Payload(hs_payload_bytes),
};
message.into_unencrypted_opaque().encode()
}

View File

@ -0,0 +1,73 @@
use super::*;
use crate::tls::Tls;
use linkerd_app_core::{
svc::ServiceExt,
tls::{NewDetectRequiredSni, ServerName},
trace, NameAddr,
};
use linkerd_proxy_client_policy as client_policy;
use std::{net::SocketAddr, str::FromStr, sync::Arc};
use tokio::sync::watch;
#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn routes() {
let _trace = trace::test::trace_init();
const AUTHORITY: &str = "logical.test.svc.cluster.local";
const PORT: u16 = 666;
let addr = SocketAddr::new([192, 0, 2, 41].into(), PORT);
let dest: NameAddr = format!("{AUTHORITY}:{PORT}")
.parse::<NameAddr>()
.expect("dest addr is valid");
let resolve = support::resolver().endpoint_exists(dest.clone(), addr, Default::default());
let (rt, _shutdown) = runtime();
let client_hello = generate_client_hello(AUTHORITY);
let (srv, io, rsp) = MockServer::new(addr, AUTHORITY, client_hello);
let mut connect = ConnectTcp::default();
connect.add_server(srv);
let stack = Outbound::new(default_config(), rt, &mut Default::default())
.with_stack(connect)
.push_tls_concrete(resolve)
.push_tls_logical()
.map_stack(|config, _rt, stk| {
stk.push_new_idle_cached(config.discovery_idle_timeout)
.push_map_target(|(sni, parent): (ServerName, _)| Tls { sni, parent })
.push(NewDetectRequiredSni::layer(Duration::from_secs(1)))
.arc_new_clone_tcp()
})
.into_inner();
let correct_backend = default_backend(addr);
let correct_route = sni_route(
correct_backend.clone(),
sni::MatchSni::Exact(AUTHORITY.into()),
);
let wrong_addr = SocketAddr::new([0, 0, 0, 0].into(), PORT);
let wrong_backend = default_backend(wrong_addr);
let wrong_route_1 = sni_route(
wrong_backend.clone(),
sni::MatchSni::from_str("foo").unwrap(),
);
let wrong_route_2 = sni_route(
wrong_backend.clone(),
sni::MatchSni::from_str("*.test.svc.cluster.local").unwrap(),
);
let (_route_tx, routes) = watch::channel(Routes {
addr: addr.into(),
backends: Arc::new([correct_backend, wrong_backend]),
routes: Arc::new([correct_route, wrong_route_1, wrong_route_2]),
meta: ParentRef(client_policy::Meta::new_default("parent")),
});
let target = Target { num: 1, routes };
let svc = stack.new_service(target);
svc.oneshot(io).await.unwrap();
let msg = rsp.await.unwrap().unwrap();
assert_eq!(msg, AUTHORITY);
}

View File

@ -1,4 +1,4 @@
use crate::{IoSlice, PeerAddr, Poll};
use crate::{IoSlice, Peek, PeerAddr, Poll};
use futures::ready;
use linkerd_errno::Errno;
use pin_project::pin_project;
@ -82,3 +82,10 @@ impl<T: PeerAddr, S> PeerAddr for SensorIo<T, S> {
self.io.peer_addr()
}
}
#[async_trait::async_trait]
impl<I: Peek + Send + Sync, S: Sensor + Sync> Peek for SensorIo<I, S> {
async fn peek(&self, buf: &mut [u8]) -> Result<usize> {
self.io.peek(buf).await
}
}

View File

@ -1,107 +0,0 @@
use crate::{
server::{detect_sni, DetectIo, Timeout},
ServerName,
};
use linkerd_error::Error;
use linkerd_io as io;
use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Service, ServiceExt};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use thiserror::Error;
use tokio::time;
use tracing::debug;
#[derive(Clone, Debug, Error)]
#[error("SNI detection timed out")]
pub struct SniDetectionTimeoutError;
#[derive(Clone, Debug, Error)]
#[error("Could not find SNI")]
pub struct NoSniFoundError;
#[derive(Clone, Debug)]
pub struct NewDetectSni<P, N> {
params: P,
inner: N,
}
#[derive(Clone, Debug)]
pub struct DetectSni<T, P, N> {
target: T,
inner: N,
timeout: Timeout,
params: P,
}
impl<P, N> NewDetectSni<P, N> {
pub fn new(params: P, inner: N) -> Self {
Self { inner, params }
}
pub fn layer(params: P) -> impl layer::Layer<N, Service = Self> + Clone
where
P: Clone,
{
layer::mk(move |inner| Self::new(params.clone(), inner))
}
}
impl<T, P, N> NewService<T> for NewDetectSni<P, N>
where
P: ExtractParam<Timeout, T> + Clone,
N: Clone,
{
type Service = DetectSni<T, P, N>;
fn new_service(&self, target: T) -> Self::Service {
let timeout = self.params.extract_param(&target);
DetectSni {
target,
timeout,
inner: self.inner.clone(),
params: self.params.clone(),
}
}
}
impl<T, P, I, N, S> Service<I> for DetectSni<T, P, N>
where
T: Clone + Send + Sync + 'static,
P: InsertParam<ServerName, T> + Clone + Send + Sync + 'static,
P::Target: Send + 'static,
I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static,
N: NewService<P::Target, Service = S> + Clone + Send + 'static,
S: Service<DetectIo<I>> + Send,
S::Error: Into<Error>,
S::Future: Send,
{
type Response = S::Response;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<S::Response, Error>> + Send + 'static>>;
#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, io: I) -> Self::Future {
let target = self.target.clone();
let new_accept = self.inner.clone();
let params = self.params.clone();
// Detect the SNI from a ClientHello (or timeout).
let Timeout(timeout) = self.timeout;
let detect = time::timeout(timeout, detect_sni(io));
Box::pin(async move {
let (sni, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??;
let sni = sni.ok_or(NoSniFoundError)?;
debug!("detected SNI: {:?}", sni);
let svc = new_accept.new_service(params.insert_param(sni, target));
svc.oneshot(io).await.map_err(Into::into)
})
}
}

View File

@ -2,12 +2,14 @@
#![forbid(unsafe_code)]
pub mod client;
pub mod detect_sni;
pub mod server;
pub use self::{
client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId},
server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls},
server::{
ClientId, ConditionalServerTls, NewDetectRequiredSni, NewDetectTls, NoServerTls,
NoSniFoundError, ServerTls, SniDetectionTimeoutError,
},
};
use linkerd_dns_name as dns;

View File

@ -1,4 +1,5 @@
mod client_hello;
mod required_sni;
use crate::{NegotiatedProtocol, ServerName};
use bytes::BytesMut;
@ -18,6 +19,8 @@ use thiserror::Error;
use tokio::time::{self, Duration};
use tracing::{debug, trace, warn};
pub use self::required_sni::{NewDetectRequiredSni, NoSniFoundError, SniDetectionTimeoutError};
/// Describes the authenticated identity of a remote client.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct ClientId(pub id::Id);
@ -65,6 +68,7 @@ pub struct NewDetectTls<L, P, N> {
_local_identity: std::marker::PhantomData<fn() -> L>,
}
/// A param type used to indicate the timeout after which detection should fail.
#[derive(Copy, Clone, Debug)]
pub struct Timeout(pub Duration);

View File

@ -0,0 +1,118 @@
use crate::{
server::{detect_sni, DetectIo},
ServerName,
};
use linkerd_error::Error;
use linkerd_io as io;
use linkerd_stack::{layer, NewService, Service, ServiceExt};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use thiserror::Error;
use tokio::time;
use tracing::debug;
#[derive(Clone, Debug, Error)]
#[error("SNI detection timed out")]
pub struct SniDetectionTimeoutError;
#[derive(Clone, Debug, Error)]
#[error("Could not find SNI")]
pub struct NoSniFoundError;
/// A NewService that instruments an inner stack with knowledge of the
/// connection's TLS ServerName (i.e. from an SNI header).
///
/// This differs from the parent module's NewDetectTls in a a few ways:
///
/// - It requires that all connections have an SNI.
/// - It assumes that these connections may not be terminated locally, so there
/// is no concept of a local server name.
/// - There are no special affordances for mutually authenticated TLS, so we
/// make no attempt to detect the client's identity.
/// - The detection timeout is fixed and cannot vary per target (for
/// convenience, to reduce needless boilerplate).
#[derive(Clone, Debug)]
pub struct NewDetectRequiredSni<N> {
inner: N,
timeout: time::Duration,
}
#[derive(Clone, Debug)]
pub struct DetectRequiredSni<T, N> {
target: T,
inner: N,
timeout: time::Duration,
}
// === impl NewDetectRequiredSni ===
impl<N> NewDetectRequiredSni<N> {
fn new(timeout: time::Duration, inner: N) -> Self {
Self { inner, timeout }
}
pub fn layer(timeout: time::Duration) -> impl layer::Layer<N, Service = Self> + Clone {
layer::mk(move |inner| Self::new(timeout, inner))
}
}
impl<T, N> NewService<T> for NewDetectRequiredSni<N>
where
N: Clone,
{
type Service = DetectRequiredSni<T, N>;
fn new_service(&self, target: T) -> Self::Service {
DetectRequiredSni::new(self.timeout, target, self.inner.clone())
}
}
// === impl DetectRequiredSni ===
impl<T, N> DetectRequiredSni<T, N> {
fn new(timeout: time::Duration, target: T, inner: N) -> Self {
Self {
target,
inner,
timeout,
}
}
}
impl<T, I, N, S> Service<I> for DetectRequiredSni<T, N>
where
T: Clone + Send + Sync + 'static,
I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static,
N: NewService<(ServerName, T), Service = S> + Clone + Send + 'static,
S: Service<DetectIo<I>> + Send,
S::Error: Into<Error>,
S::Future: Send,
{
type Response = S::Response;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<S::Response, Error>> + Send + 'static>>;
#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, io: I) -> Self::Future {
let target = self.target.clone();
let new_accept = self.inner.clone();
// Detect the SNI from a ClientHello (or timeout).
let detect = time::timeout(self.timeout, detect_sni(io));
Box::pin(async move {
let (res, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??;
let sni = res.ok_or(NoSniFoundError)?;
debug!(?sni, "Detected TLS");
let svc = new_accept.new_service((sni, target));
svc.oneshot(io).await.map_err(Into::into)
})
}
}