feat(outbound): Route TLS connections by SNI (#3160)
This is a draft PR that wires together some of the recently introduced changes to the proxy in order to deliver TLS routing functionality. It adds a TLS routing stack with the following properties:

- it always expects `SNI` to be present
- it uses the TLS routing information provided by the discovery API to perform routing based on SNI
- it does not concern itself with terminating TLS; it simply proxies the encrypted stream

Signed-off-by: Zahari Dichev <zaharidichev@gmail.com>
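Below is a minimal, self-contained Rust sketch of the routing decision described above: read the SNI from the client's ClientHello, pick the first configured route whose SNI matcher accepts it, and hand the still-encrypted stream to that route's backend. The `SniMatch`, `Route`, and `select_backend` names are illustrative stand-ins invented for this sketch, not the proxy's real types; in the actual change the SNI is extracted by the new `NewDetectRequiredSni` layer and the lookup is performed by the router's `SelectRoute` implementation in the new `tls::logical::router` module. The exact-match hostname and address below are borrowed from the tests in this diff; the `.example.com` entries are made up.

// Illustrative stand-ins for SNI-based route selection; not the proxy's real types.
#[derive(Debug, Clone, PartialEq, Eq)]
enum SniMatch {
    /// Matches the server name exactly.
    Exact(String),
    /// Matches any server name with the given suffix (a wildcard-style rule).
    Suffix(String),
}

impl SniMatch {
    fn matches(&self, sni: &str) -> bool {
        match self {
            SniMatch::Exact(name) => name == sni,
            SniMatch::Suffix(suffix) => sni.ends_with(suffix.as_str()),
        }
    }
}

#[derive(Debug, Clone)]
struct Route {
    snis: Vec<SniMatch>,
    /// Stand-in for a concrete backend (e.g. a balancer or forward target).
    backend: String,
}

/// Returns the backend of the first route whose SNI matcher accepts `sni`.
/// SNI is required: connections without one are rejected before this point,
/// and the selected backend receives the encrypted bytes unmodified.
fn select_backend<'a>(routes: &'a [Route], sni: &str) -> Option<&'a str> {
    routes
        .iter()
        .find(|route| route.snis.iter().any(|m| m.matches(sni)))
        .map(|route| route.backend.as_str())
}

fn main() {
    let routes = vec![
        Route {
            snis: vec![SniMatch::Exact("logical.test.svc.cluster.local".to_string())],
            backend: "192.0.2.41:666".to_string(),
        },
        Route {
            snis: vec![SniMatch::Suffix(".example.com".to_string())],
            backend: "192.0.2.42:8443".to_string(),
        },
    ];

    assert_eq!(
        select_backend(&routes, "logical.test.svc.cluster.local"),
        Some("192.0.2.41:666")
    );
    assert_eq!(select_backend(&routes, "api.example.com"), Some("192.0.2.42:8443"));
    assert_eq!(select_backend(&routes, "unknown.host"), None);
}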
parent b65c43aac5
commit 73e96ddb12
@@ -1341,6 +1341,7 @@ dependencies = [
|
|||
"linkerd-proxy-client-policy",
|
||||
"linkerd-retry",
|
||||
"linkerd-stack",
|
||||
"linkerd-tls-route",
|
||||
"linkerd-tonic-stream",
|
||||
"linkerd-tonic-watch",
|
||||
"linkerd-tracing",
|
||||
|
|
@@ -1351,6 +1352,7 @@ dependencies = [
|
|||
"prometheus-client",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-test",
|
||||
"tonic",
|
||||
"tower",
|
||||
|
|
|
|||
|
|
@@ -44,12 +44,14 @@ linkerd-proxy-client-policy = { path = "../../proxy/client-policy", features = [
|
|||
"proto",
|
||||
] }
|
||||
linkerd-retry = { path = "../../retry" }
|
||||
linkerd-tls-route = { path = "../../tls/route" }
|
||||
linkerd-tonic-stream = { path = "../../tonic-stream" }
|
||||
linkerd-tonic-watch = { path = "../../tonic-watch" }
|
||||
|
||||
[dev-dependencies]
|
||||
hyper = { version = "0.14", features = ["http1", "http2"] }
|
||||
tokio = { version = "1", features = ["macros", "sync", "time"] }
|
||||
tokio-rustls = "0.24"
|
||||
tokio-test = "0.4"
|
||||
tower-test = "0.4"
|
||||
|
||||
|
|
|
|||
|
|
@@ -33,10 +33,13 @@ pub use self::balance::BalancerMetrics;
|
|||
pub enum Dispatch {
|
||||
Balance(NameAddr, EwmaConfig),
|
||||
Forward(Remote<ServerAddr>, Metadata),
|
||||
Fail { message: Arc<str> },
|
||||
/// A backend dispatcher that explicitly fails all requests.
|
||||
Fail {
|
||||
message: Arc<str>,
|
||||
},
|
||||
}
|
||||
|
||||
/// A backend dispatcher explicitly fails all requests.
|
||||
/// A backend dispatcher that explicitly fails all requests.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{0}")]
|
||||
pub struct DispatcherFailed(Arc<str>);
|
||||
|
|
|
|||
|
|
@@ -21,7 +21,7 @@ use linkerd_app_core::{
|
|||
tap,
|
||||
},
|
||||
svc::{self, ServiceExt},
|
||||
tls,
|
||||
tls::ConnectMeta as TlsConnectMeta,
|
||||
transport::addrs::*,
|
||||
AddrMatch, Error, ProxyRuntime,
|
||||
};
|
||||
|
|
@@ -46,6 +46,7 @@ mod sidecar;
|
|||
pub mod tcp;
|
||||
#[cfg(any(test, feature = "test-util"))]
|
||||
pub mod test_util;
|
||||
pub mod tls;
|
||||
mod zone;
|
||||
|
||||
pub use self::discover::{spawn_synthesized_profile_policy, synthesize_forward_policy, Discovery};
|
||||
|
|
@@ -100,7 +101,7 @@ struct Runtime {
|
|||
drain: drain::Watch,
|
||||
}
|
||||
|
||||
pub type ConnectMeta = tls::ConnectMeta<Local<ClientAddr>>;
|
||||
pub type ConnectMeta = TlsConnectMeta<Local<ClientAddr>>;
|
||||
|
||||
/// A reference to a frontend/apex resource, usually a service.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
|
|
|
|||
|
|
@@ -37,6 +37,7 @@ pub struct OutboundMetrics {
|
|||
pub(crate) struct PromMetrics {
|
||||
pub(crate) http: crate::http::HttpMetrics,
|
||||
pub(crate) opaq: crate::opaq::OpaqMetrics,
|
||||
pub(crate) tls: crate::tls::TlsMetrics,
|
||||
pub(crate) zone: crate::zone::TcpZoneMetrics,
|
||||
}
|
||||
|
||||
|
|
@@ -92,8 +93,14 @@ impl PromMetrics {
|
|||
|
||||
let opaq = crate::opaq::OpaqMetrics::register(registry.sub_registry_with_prefix("tcp"));
|
||||
let zone = crate::zone::TcpZoneMetrics::register(registry.sub_registry_with_prefix("tcp"));
|
||||
let tls = crate::tls::TlsMetrics::register(registry.sub_registry_with_prefix("tls"));
|
||||
|
||||
Self { http, opaq, zone }
|
||||
Self {
|
||||
http,
|
||||
opaq,
|
||||
tls,
|
||||
zone,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@@ -15,6 +15,7 @@ pub enum Protocol {
|
|||
Http2,
|
||||
Detect,
|
||||
Opaque,
|
||||
Tls,
|
||||
}
|
||||
|
||||
// === impl Outbound ===
|
||||
|
|
@@ -29,6 +30,7 @@ impl<N> Outbound<N> {
|
|||
pub fn push_protocol<T, I, NSvc>(
|
||||
self,
|
||||
http: svc::ArcNewCloneHttp<Http<T>>,
|
||||
tls: svc::ArcNewCloneTcp<T, io::EitherIo<I, io::PrefixedIo<I>>>,
|
||||
) -> Outbound<svc::ArcNewTcp<T, I>>
|
||||
where
|
||||
// Target type indicating whether detection should be skipped.
|
||||
|
|
@@ -83,7 +85,14 @@ impl<N> Outbound<N> {
|
|||
http.map_stack(|_, _, http| {
|
||||
// First separate traffic that needs protocol detection. Then switch
|
||||
// between traffic that is known to be HTTP or opaque.
|
||||
http.push_switch(Ok::<_, Infallible>, opaq.clone().into_inner())
|
||||
let known = http.push_switch(
|
||||
Ok::<_, Infallible>,
|
||||
opaq.clone()
|
||||
.push_switch(Ok::<_, Infallible>, tls.clone())
|
||||
.into_inner(),
|
||||
);
|
||||
|
||||
known
|
||||
.push_on_service(svc::MapTargetLayer::new(io::EitherIo::Left))
|
||||
.push_switch(
|
||||
|parent: T| -> Result<_, Infallible> {
|
||||
|
|
@@ -96,7 +105,12 @@ impl<N> Outbound<N> {
|
|||
version: http::Version::H2,
|
||||
parent,
|
||||
}))),
|
||||
Protocol::Opaque => Ok(svc::Either::A(svc::Either::B(parent))),
|
||||
Protocol::Opaque => {
|
||||
Ok(svc::Either::A(svc::Either::B(svc::Either::A(parent))))
|
||||
}
|
||||
Protocol::Tls => {
|
||||
Ok(svc::Either::A(svc::Either::B(svc::Either::B(parent))))
|
||||
}
|
||||
Protocol::Detect => Ok(svc::Either::B(parent)),
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
http, opaq, policy,
|
||||
protocol::{self, Protocol},
|
||||
Discovery, Outbound, ParentRef,
|
||||
tls, Discovery, Outbound, ParentRef,
|
||||
};
|
||||
use linkerd_app_core::{
|
||||
io, profiles,
|
||||
|
|
@@ -32,6 +32,12 @@ struct HttpSidecar {
|
|||
routes: watch::Receiver<http::Routes>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TlsSidecar {
|
||||
orig_dst: OrigDstAddr,
|
||||
routes: watch::Receiver<tls::Routes>,
|
||||
}
|
||||
|
||||
// === impl Outbound ===
|
||||
|
||||
impl Outbound<()> {
|
||||
|
|
@@ -53,6 +59,12 @@ impl Outbound<()> {
|
|||
R::Resolution: Unpin,
|
||||
{
|
||||
let opaq = self.to_tcp_connect().push_opaq_cached(resolve.clone());
|
||||
let tls = self
|
||||
.to_tcp_connect()
|
||||
.push_tls_cached(resolve.clone())
|
||||
.into_stack()
|
||||
.push_map_target(TlsSidecar::from)
|
||||
.arc_new_clone_tcp();
|
||||
|
||||
let http = self
|
||||
.to_tcp_connect()
|
||||
|
|
@@ -64,7 +76,8 @@ impl Outbound<()> {
|
|||
.push_map_target(HttpSidecar::from)
|
||||
.arc_new_clone_http();
|
||||
|
||||
opaq.push_protocol(http.into_inner())
|
||||
opaq.clone()
|
||||
.push_protocol(http.into_inner(), tls.into_inner())
|
||||
// Use a dedicated target type to bind discovery results to the
|
||||
// outbound sidecar stack configuration.
|
||||
.map_stack(move |_, _, stk| stk.push_map_target(Sidecar::from))
|
||||
|
|
@@ -131,7 +144,8 @@ impl svc::Param<Protocol> for Sidecar {
|
|||
match self.policy.borrow().protocol {
|
||||
policy::Protocol::Http1(_) => Protocol::Http1,
|
||||
policy::Protocol::Http2(_) | policy::Protocol::Grpc(_) => Protocol::Http2,
|
||||
policy::Protocol::Opaque(_) | policy::Protocol::Tls(_) => Protocol::Opaque,
|
||||
policy::Protocol::Opaque(_) => Protocol::Opaque,
|
||||
policy::Protocol::Tls(_) => Protocol::Tls,
|
||||
policy::Protocol::Detect { .. } => Protocol::Detect,
|
||||
}
|
||||
}
|
||||
|
|
@@ -326,3 +340,62 @@ impl std::hash::Hash for HttpSidecar {
|
|||
self.version.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
// === impl TlsSidecar ===
|
||||
|
||||
impl From<Sidecar> for TlsSidecar {
|
||||
fn from(parent: Sidecar) -> Self {
|
||||
let orig_dst = parent.orig_dst;
|
||||
let mut policy = parent.policy.clone();
|
||||
|
||||
let init = Self::mk_policy_routes(orig_dst, &policy.borrow_and_update())
|
||||
.expect("initial policy must be tls");
|
||||
let routes = tls::spawn_routes(policy, init, move |policy: &policy::ClientPolicy| {
|
||||
Self::mk_policy_routes(orig_dst, policy)
|
||||
});
|
||||
TlsSidecar { orig_dst, routes }
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsSidecar {
|
||||
fn mk_policy_routes(
|
||||
OrigDstAddr(orig_dst): OrigDstAddr,
|
||||
policy: &policy::ClientPolicy,
|
||||
) -> Option<tls::Routes> {
|
||||
let parent_ref = ParentRef(policy.parent.clone());
|
||||
let routes = match policy.protocol {
|
||||
policy::Protocol::Tls(policy::tls::Tls { ref routes }) => routes.clone(),
|
||||
_ => {
|
||||
tracing::info!("Ignoring a discovery update that changed a route from TLS");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
Some(tls::Routes {
|
||||
addr: orig_dst.into(),
|
||||
meta: parent_ref,
|
||||
routes,
|
||||
backends: policy.backends.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl svc::Param<watch::Receiver<tls::Routes>> for TlsSidecar {
|
||||
fn param(&self) -> watch::Receiver<tls::Routes> {
|
||||
self.routes.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::cmp::PartialEq for TlsSidecar {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.orig_dst == other.orig_dst
|
||||
}
|
||||
}
|
||||
|
||||
impl std::cmp::Eq for TlsSidecar {}
|
||||
|
||||
impl std::hash::Hash for TlsSidecar {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.orig_dst.hash(state);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,134 @@
|
|||
use crate::{tcp, Outbound};
|
||||
use linkerd_app_core::{
|
||||
io,
|
||||
metrics::prom,
|
||||
proxy::{
|
||||
api_resolve::{ConcreteAddr, Metadata},
|
||||
core::Resolve,
|
||||
},
|
||||
svc,
|
||||
tls::{NewDetectRequiredSni, ServerName},
|
||||
transport::addrs::*,
|
||||
Error,
|
||||
};
|
||||
use std::{fmt::Debug, hash::Hash};
|
||||
use tokio::sync::watch;
|
||||
|
||||
mod concrete;
|
||||
mod logical;
|
||||
|
||||
pub use self::logical::{Concrete, Routes};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
struct Tls<T> {
|
||||
sni: ServerName,
|
||||
parent: T,
|
||||
}
|
||||
|
||||
pub fn spawn_routes<T>(
|
||||
mut route_rx: watch::Receiver<T>,
|
||||
init: Routes,
|
||||
mut mk: impl FnMut(&T) -> Option<Routes> + Send + Sync + 'static,
|
||||
) -> watch::Receiver<Routes>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
{
|
||||
let (tx, rx) = watch::channel(init);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let res = tokio::select! {
|
||||
biased;
|
||||
_ = tx.closed() => return,
|
||||
res = route_rx.changed() => res,
|
||||
};
|
||||
|
||||
if res.is_err() {
|
||||
// Drop the `tx` sender when the profile sender is
|
||||
// dropped.
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(routes) = (mk)(&*route_rx.borrow_and_update()) {
|
||||
if tx.send(routes).is_err() {
|
||||
// Drop the `tx` sender when all of its receivers are dropped.
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct TlsMetrics {
|
||||
balance: concrete::BalancerMetrics,
|
||||
}
|
||||
|
||||
// === impl Outbound ===
|
||||
|
||||
impl<C> Outbound<C> {
|
||||
/// Builds a stack that proxies TLS connections.
|
||||
///
|
||||
/// This stack uses caching so that a router/load-balancer may be reused
|
||||
/// across multiple connections.
|
||||
pub fn push_tls_cached<T, I, R>(self, resolve: R) -> Outbound<svc::ArcNewCloneTcp<T, I>>
|
||||
where
|
||||
// Tls target
|
||||
T: Clone + Debug + PartialEq + Eq + Hash + Send + Sync + 'static,
|
||||
T: svc::Param<watch::Receiver<Routes>>,
|
||||
// Server-side connection
|
||||
I: io::AsyncRead + io::AsyncWrite + io::PeerAddr + io::Peek,
|
||||
I: Debug + Send + Sync + Unpin + 'static,
|
||||
// Endpoint discovery
|
||||
R: Resolve<ConcreteAddr, Endpoint = Metadata, Error = Error>,
|
||||
R::Resolution: Unpin,
|
||||
// TCP endpoint stack.
|
||||
C: svc::MakeConnection<tcp::Connect, Metadata = Local<ClientAddr>, Error = io::Error>,
|
||||
C: Clone + Send + Sync + Unpin + 'static,
|
||||
C::Connection: Send + Unpin,
|
||||
C::Future: Send + Unpin,
|
||||
{
|
||||
self.push_tcp_endpoint()
|
||||
.push_tls_concrete(resolve)
|
||||
.push_tls_logical()
|
||||
.map_stack(|config, _rt, stk| {
|
||||
stk.push_new_idle_cached(config.discovery_idle_timeout)
|
||||
// Use a dedicated target type to configure parameters for
|
||||
// the TLS stack. It also helps narrow the cache key.
|
||||
.push_map_target(|(sni, parent): (ServerName, T)| Tls { sni, parent })
|
||||
.push(NewDetectRequiredSni::layer(
|
||||
config.proxy.detect_protocol_timeout,
|
||||
))
|
||||
.arc_new_clone_tcp()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// === impl Tls ===
|
||||
|
||||
impl<T> svc::Param<ServerName> for Tls<T> {
|
||||
fn param(&self) -> ServerName {
|
||||
self.sni.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<watch::Receiver<logical::Routes>> for Tls<T>
|
||||
where
|
||||
T: svc::Param<watch::Receiver<logical::Routes>>,
|
||||
{
|
||||
fn param(&self) -> watch::Receiver<logical::Routes> {
|
||||
self.parent.param()
|
||||
}
|
||||
}
|
||||
|
||||
// === impl TlsMetrics ===
|
||||
|
||||
impl TlsMetrics {
|
||||
pub fn register(registry: &mut prom::Registry) -> Self {
|
||||
let balance =
|
||||
concrete::BalancerMetrics::register(registry.sub_registry_with_prefix("balancer"));
|
||||
Self { balance }
|
||||
}
|
||||
}
|
||||
|
|
@@ -0,0 +1,387 @@
|
|||
use crate::{
|
||||
metrics::BalancerMetricsParams,
|
||||
stack_labels,
|
||||
zone::{tcp_zone_labels, TcpZoneLabels},
|
||||
BackendRef, Outbound, ParentRef,
|
||||
};
|
||||
use linkerd_app_core::{
|
||||
config::QueueConfig,
|
||||
drain, io,
|
||||
metrics::{
|
||||
self,
|
||||
prom::{self, EncodeLabelSetMut},
|
||||
OutboundZoneLocality,
|
||||
},
|
||||
proxy::{
|
||||
api_resolve::{ConcreteAddr, Metadata},
|
||||
core::Resolve,
|
||||
http::AuthorityOverride,
|
||||
tcp::{self, balance},
|
||||
},
|
||||
svc::{self, layer::Layer},
|
||||
tls::{self, ServerName},
|
||||
transport::{self, addrs::*},
|
||||
transport_header::SessionProtocol,
|
||||
Error, Infallible, NameAddr,
|
||||
};
|
||||
use std::{fmt::Debug, net::SocketAddr, sync::Arc};
|
||||
use tracing::info_span;
|
||||
|
||||
/// Parameter configuring dispatcher behavior.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Dispatch {
|
||||
Balance(NameAddr, balance::EwmaConfig),
|
||||
Forward(Remote<ServerAddr>, Metadata),
|
||||
/// A backend dispatcher that explicitly fails all requests.
|
||||
Fail {
|
||||
message: Arc<str>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{0}")]
|
||||
pub struct DispatcherFailed(Arc<str>);
|
||||
|
||||
/// Wraps errors encountered in this module.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("concrete service {addr}: {source}")]
|
||||
pub struct ConcreteError {
|
||||
addr: NameAddr,
|
||||
#[source]
|
||||
source: Error,
|
||||
}
|
||||
|
||||
/// Inner stack target type.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct Endpoint<T> {
|
||||
addr: Remote<ServerAddr>,
|
||||
is_local: bool,
|
||||
metadata: Metadata,
|
||||
parent: T,
|
||||
}
|
||||
|
||||
pub type BalancerMetrics = BalancerMetricsParams<ConcreteLabels>;
|
||||
|
||||
/// A target configuring a load balancer stack.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
struct Balance<T> {
|
||||
concrete: NameAddr,
|
||||
ewma: balance::EwmaConfig,
|
||||
queue: QueueConfig,
|
||||
parent: T,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
pub struct ConcreteLabels {
|
||||
concrete: Arc<str>,
|
||||
}
|
||||
|
||||
impl prom::EncodeLabelSetMut for ConcreteLabels {
|
||||
fn encode_label_set(&self, enc: &mut prom::encoding::LabelSetEncoder<'_>) -> std::fmt::Result {
|
||||
use prom::encoding::EncodeLabel;
|
||||
|
||||
("concrete", &*self.concrete).encode(enc.encode_label())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl prom::encoding::EncodeLabelSet for ConcreteLabels {
|
||||
fn encode(&self, mut enc: prom::encoding::LabelSetEncoder<'_>) -> std::fmt::Result {
|
||||
self.encode_label_set(&mut enc)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::ExtractParam<balance::Metrics, Balance<T>> for BalancerMetricsParams<ConcreteLabels> {
|
||||
fn extract_param(&self, bal: &Balance<T>) -> balance::Metrics {
|
||||
self.metrics(&ConcreteLabels {
|
||||
concrete: bal.concrete.to_string().into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// === impl Outbound ===
|
||||
|
||||
impl<C> Outbound<C> {
|
||||
/// Builds a [`svc::NewService`] stack that builds buffered tls services
|
||||
/// for `T`-typed concrete targets. Connections may be load balanced across
|
||||
/// a discovered set of replicas or forwarded to a single endpoint,
|
||||
/// depending on the value of the `Dispatch` parameter.
|
||||
///
|
||||
/// When a balancer has no available inner services, it goes into
|
||||
/// 'failfast'. While in failfast, buffered requests are failed and the
|
||||
/// service becomes unavailable so callers may choose alternate concrete
|
||||
/// services.
|
||||
pub fn push_tls_concrete<T, I, R>(
|
||||
self,
|
||||
resolve: R,
|
||||
) -> Outbound<
|
||||
svc::ArcNewService<
|
||||
T,
|
||||
impl svc::Service<I, Response = (), Error = Error, Future = impl Send> + Clone,
|
||||
>,
|
||||
>
|
||||
where
|
||||
// Logical target
|
||||
T: svc::Param<Dispatch>,
|
||||
T: Clone + Debug + Send + Sync + 'static,
|
||||
T: svc::Param<ServerName>,
|
||||
// Server-side socket.
|
||||
I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static,
|
||||
// Endpoint resolution.
|
||||
R: Resolve<ConcreteAddr, Endpoint = Metadata, Error = Error>,
|
||||
R::Resolution: Unpin,
|
||||
// Endpoint connector.
|
||||
C: svc::MakeConnection<Endpoint<T>> + Clone + Send + 'static,
|
||||
C::Connection: Send + Unpin,
|
||||
C::Metadata: Send + Unpin,
|
||||
C::Future: Send,
|
||||
C: Send + Sync + 'static,
|
||||
{
|
||||
let resolve =
|
||||
svc::MapTargetLayer::new(|t: Balance<T>| -> ConcreteAddr { ConcreteAddr(t.concrete) })
|
||||
.layer(resolve.into_service());
|
||||
|
||||
self.map_stack(|config, rt, inner| {
|
||||
let queue = config.tcp_connection_queue;
|
||||
|
||||
let connect = inner
|
||||
.push(svc::stack::WithoutConnectionMetadata::layer())
|
||||
.push_new_thunk();
|
||||
|
||||
let forward = connect
|
||||
.clone()
|
||||
.push_on_service(rt.metrics.proxy.stack.layer(stack_labels("tls", "forward")))
|
||||
.instrument(|e: &Endpoint<T>| info_span!("forward", addr = %e.addr));
|
||||
|
||||
let endpoint = connect
|
||||
.push_on_service(
|
||||
rt.metrics
|
||||
.proxy
|
||||
.stack
|
||||
.layer(stack_labels("tls", "endpoint")),
|
||||
)
|
||||
.instrument(|e: &Endpoint<T>| info_span!("endpoint", addr = %e.addr));
|
||||
|
||||
let fail = svc::ArcNewService::new(|message: Arc<str>| {
|
||||
svc::mk(move |_| futures::future::ready(Err(DispatcherFailed(message.clone()))))
|
||||
});
|
||||
|
||||
let inbound_ips = config.inbound_ips.clone();
|
||||
let balance = endpoint
|
||||
.push_map_target(
|
||||
move |((addr, metadata), target): ((SocketAddr, Metadata), Balance<T>)| {
|
||||
tracing::trace!(%addr, ?metadata, ?target, "Resolved endpoint");
|
||||
let is_local = inbound_ips.contains(&addr.ip());
|
||||
Endpoint {
|
||||
addr: Remote(ServerAddr(addr)),
|
||||
metadata,
|
||||
is_local,
|
||||
parent: target.parent,
|
||||
}
|
||||
},
|
||||
)
|
||||
.lift_new_with_target()
|
||||
.push(tcp::NewBalance::layer(
|
||||
resolve,
|
||||
rt.metrics.prom.tls.balance.clone(),
|
||||
))
|
||||
.push(svc::NewMapErr::layer_from_target::<ConcreteError, _>())
|
||||
.push_on_service(rt.metrics.proxy.stack.layer(stack_labels("tls", "balance")))
|
||||
.instrument(|t: &Balance<T>| info_span!("balance", addr = %t.concrete));
|
||||
|
||||
balance
|
||||
.push_switch(Ok::<_, Infallible>, forward.into_inner())
|
||||
.push_switch(
|
||||
move |parent: T| -> Result<_, Infallible> {
|
||||
Ok(match parent.param() {
|
||||
Dispatch::Balance(concrete, ewma) => {
|
||||
svc::Either::A(svc::Either::A(Balance {
|
||||
concrete,
|
||||
ewma,
|
||||
queue,
|
||||
parent,
|
||||
}))
|
||||
}
|
||||
|
||||
Dispatch::Forward(addr, meta) => {
|
||||
svc::Either::A(svc::Either::B(Endpoint {
|
||||
addr,
|
||||
is_local: false,
|
||||
metadata: meta,
|
||||
parent,
|
||||
}))
|
||||
}
|
||||
Dispatch::Fail { message } => svc::Either::B(message),
|
||||
})
|
||||
},
|
||||
svc::stack(fail).check_new_clone().into_inner(),
|
||||
)
|
||||
.push_on_service(tcp::Forward::layer())
|
||||
.push_on_service(drain::Retain::layer(rt.drain.clone()))
|
||||
.push(svc::ArcNewService::layer())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// === impl ConcreteError ===
|
||||
|
||||
impl<T> From<(&Balance<T>, Error)> for ConcreteError {
|
||||
fn from((target, source): (&Balance<T>, Error)) -> Self {
|
||||
Self {
|
||||
addr: target.concrete.clone(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === impl Balance ===
|
||||
|
||||
impl<T> std::ops::Deref for Balance<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.parent
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<balance::EwmaConfig> for Balance<T> {
|
||||
fn param(&self) -> balance::EwmaConfig {
|
||||
self.ewma
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<svc::queue::Capacity> for Balance<T> {
|
||||
fn param(&self) -> svc::queue::Capacity {
|
||||
svc::queue::Capacity(self.queue.capacity)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<svc::queue::Timeout> for Balance<T> {
|
||||
fn param(&self) -> svc::queue::Timeout {
|
||||
svc::queue::Timeout(self.queue.failfast_timeout)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: svc::Param<ParentRef>> svc::Param<ParentRef> for Balance<T> {
|
||||
fn param(&self) -> ParentRef {
|
||||
self.parent.param()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: svc::Param<BackendRef>> svc::Param<BackendRef> for Balance<T> {
|
||||
fn param(&self) -> BackendRef {
|
||||
self.parent.param()
|
||||
}
|
||||
}
|
||||
|
||||
// === impl Endpoint ===
|
||||
|
||||
impl<T> svc::Param<Remote<ServerAddr>> for Endpoint<T> {
|
||||
fn param(&self) -> Remote<ServerAddr> {
|
||||
self.addr
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<Option<crate::tcp::tagged_transport::PortOverride>> for Endpoint<T> {
|
||||
fn param(&self) -> Option<crate::tcp::tagged_transport::PortOverride> {
|
||||
if self.is_local {
|
||||
return None;
|
||||
}
|
||||
self.metadata
|
||||
.tagged_transport_port()
|
||||
.map(crate::tcp::tagged_transport::PortOverride)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<Option<AuthorityOverride>> for Endpoint<T> {
|
||||
fn param(&self) -> Option<AuthorityOverride> {
|
||||
if self.is_local {
|
||||
return None;
|
||||
}
|
||||
self.metadata
|
||||
.authority_override()
|
||||
.cloned()
|
||||
.map(AuthorityOverride)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<Option<SessionProtocol>> for Endpoint<T> {
|
||||
fn param(&self) -> Option<SessionProtocol> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<transport::labels::Key> for Endpoint<T>
|
||||
where
|
||||
T: svc::Param<ServerName>,
|
||||
{
|
||||
fn param(&self) -> transport::labels::Key {
|
||||
transport::labels::Key::OutboundClient(self.param())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<metrics::OutboundEndpointLabels> for Endpoint<T>
|
||||
where
|
||||
T: svc::Param<ServerName>,
|
||||
{
|
||||
fn param(&self) -> metrics::OutboundEndpointLabels {
|
||||
metrics::OutboundEndpointLabels {
|
||||
authority: None,
|
||||
labels: metrics::prefix_labels("dst", self.metadata.labels().iter()),
|
||||
zone_locality: self.param(),
|
||||
server_id: self.param(),
|
||||
target_addr: self.addr.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<OutboundZoneLocality> for Endpoint<T> {
|
||||
fn param(&self) -> OutboundZoneLocality {
|
||||
OutboundZoneLocality::new(&self.metadata)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<TcpZoneLabels> for Endpoint<T> {
|
||||
fn param(&self) -> TcpZoneLabels {
|
||||
tcp_zone_labels(self.param())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<metrics::EndpointLabels> for Endpoint<T>
|
||||
where
|
||||
T: svc::Param<ServerName>,
|
||||
{
|
||||
fn param(&self) -> metrics::EndpointLabels {
|
||||
metrics::EndpointLabels::from(svc::Param::<metrics::OutboundEndpointLabels>::param(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<tls::ConditionalClientTls> for Endpoint<T> {
|
||||
fn param(&self) -> tls::ConditionalClientTls {
|
||||
if self.is_local {
|
||||
return tls::ConditionalClientTls::None(tls::NoClientTls::Loopback);
|
||||
}
|
||||
|
||||
// If we're transporting an opaque protocol OR we're communicating with
|
||||
// a gateway, then set an ALPN value indicating support for a transport
|
||||
// header.
|
||||
let use_transport_header = self.metadata.tagged_transport_port().is_some()
|
||||
|| self.metadata.authority_override().is_some();
|
||||
self.metadata
|
||||
.identity()
|
||||
.cloned()
|
||||
.map(move |mut client_tls| {
|
||||
client_tls.alpn = if use_transport_header {
|
||||
use linkerd_app_core::transport_header::PROTOCOL;
|
||||
Some(tls::client::AlpnProtocols(vec![PROTOCOL.into()]))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tls::ConditionalClientTls::Some(client_tls)
|
||||
})
|
||||
.unwrap_or(tls::ConditionalClientTls::None(
|
||||
tls::NoClientTls::NotProvidedByServiceDiscovery,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
@@ -0,0 +1,112 @@
|
|||
use super::concrete;
|
||||
use crate::{BackendRef, Outbound, ParentRef};
|
||||
use linkerd_app_core::{io, svc, tls::ServerName, Addr, Error};
|
||||
use linkerd_proxy_client_policy as client_policy;
|
||||
use std::{fmt::Debug, hash::Hash, sync::Arc};
|
||||
use tokio::sync::watch;
|
||||
|
||||
pub mod route;
|
||||
pub mod router;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// Indicates the address used for logical routing.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct LogicalAddr(pub Addr);
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct Routes {
|
||||
pub addr: Addr,
|
||||
pub meta: ParentRef,
|
||||
pub routes: Arc<[client_policy::tls::Route]>,
|
||||
pub backends: Arc<[client_policy::Backend]>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct Concrete<T> {
|
||||
target: concrete::Dispatch,
|
||||
parent: T,
|
||||
parent_ref: ParentRef,
|
||||
backend_ref: BackendRef,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("no route")]
|
||||
pub struct NoRoute;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("logical service {addr}: {source}")]
|
||||
pub struct LogicalError {
|
||||
addr: Addr,
|
||||
#[source]
|
||||
source: Error,
|
||||
}
|
||||
|
||||
impl<N> Outbound<N> {
|
||||
/// Builds a `NewService` that produces a router service for each logical
|
||||
/// target.
|
||||
///
|
||||
/// The router uses discovery information (provided on the target) to
|
||||
/// support per-connection routing over a set of concrete inner services.
|
||||
/// Only available inner services are used for routing. When there are no
|
||||
/// available backends, requests are failed with a [`svc::stack::LoadShedError`].
|
||||
pub fn push_tls_logical<T, I, NSvc>(self) -> Outbound<svc::ArcNewCloneTcp<T, I>>
|
||||
where
|
||||
// Logical target.
|
||||
T: svc::Param<watch::Receiver<Routes>>,
|
||||
T: svc::Param<ServerName>,
|
||||
T: Eq + Hash + Clone + Debug + Send + Sync + 'static,
|
||||
// Concrete stack.
|
||||
I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static,
|
||||
// Concrete stack.
|
||||
N: svc::NewService<Concrete<T>, Service = NSvc> + Clone + Send + Sync + 'static,
|
||||
NSvc: svc::Service<I, Response = ()> + Clone + Send + Sync + 'static,
|
||||
NSvc::Future: Send,
|
||||
NSvc::Error: Into<Error>,
|
||||
{
|
||||
self.map_stack(|_config, _, concrete| {
|
||||
concrete
|
||||
.lift_new()
|
||||
.push_on_service(svc::layer::mk(move |concrete: N| {
|
||||
svc::stack(concrete.clone())
|
||||
.push(router::Router::layer())
|
||||
.push(svc::NewMapErr::layer_from_target::<LogicalError, _>())
|
||||
.arc_new_clone_tcp()
|
||||
.into_inner()
|
||||
}))
|
||||
// Rebuild the inner router stack every time the watch changes.
|
||||
.push(svc::NewSpawnWatch::<Routes, _>::layer_into::<
|
||||
router::Router<T>,
|
||||
>())
|
||||
.arc_new_clone_tcp()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// === impl LogicalError ===
|
||||
|
||||
impl<T> From<(&router::Router<T>, Error)> for LogicalError
|
||||
where
|
||||
T: Eq + Hash + Clone + Debug,
|
||||
{
|
||||
fn from((target, source): (&router::Router<T>, Error)) -> Self {
|
||||
let LogicalAddr(addr) = svc::Param::param(target);
|
||||
Self { addr, source }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<concrete::Dispatch> for Concrete<T> {
|
||||
fn param(&self) -> concrete::Dispatch {
|
||||
self.target.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<ServerName> for Concrete<T>
|
||||
where
|
||||
T: svc::Param<ServerName>,
|
||||
{
|
||||
fn param(&self) -> ServerName {
|
||||
self.parent.param()
|
||||
}
|
||||
}
|
||||
|
|
@@ -0,0 +1,100 @@
|
|||
use super::super::Concrete;
|
||||
use crate::{ParentRef, RouteRef};
|
||||
use linkerd_app_core::{io, svc, Addr, Error};
|
||||
use linkerd_distribute as distribute;
|
||||
use linkerd_tls_route as tls_route;
|
||||
use std::{fmt::Debug, hash::Hash};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct Backend<T> {
|
||||
pub(crate) route_ref: RouteRef,
|
||||
pub(crate) concrete: Concrete<T>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct MatchedRoute<T> {
|
||||
pub(super) r#match: tls_route::RouteMatch,
|
||||
pub(super) params: Route<T>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct Route<T> {
|
||||
pub(super) parent: T,
|
||||
pub(super) addr: Addr,
|
||||
pub(super) parent_ref: ParentRef,
|
||||
pub(super) route_ref: RouteRef,
|
||||
pub(super) distribution: BackendDistribution<T>,
|
||||
}
|
||||
|
||||
pub(crate) type BackendDistribution<T> = distribute::Distribution<Backend<T>>;
|
||||
pub(crate) type NewDistribute<T, N> = distribute::NewDistribute<Backend<T>, (), N>;
|
||||
|
||||
/// Wraps errors with route metadata.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("route {}: {source}", route.0)]
|
||||
struct RouteError {
|
||||
route: RouteRef,
|
||||
#[source]
|
||||
source: Error,
|
||||
}
|
||||
|
||||
// === impl Backend ===
|
||||
|
||||
impl<T: Clone> Clone for Backend<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
route_ref: self.route_ref.clone(),
|
||||
concrete: self.concrete.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === impl MatchedRoute ===
|
||||
|
||||
impl<T> MatchedRoute<T>
|
||||
where
|
||||
// Parent target.
|
||||
T: Debug + Eq + Hash,
|
||||
T: Clone + Send + Sync + 'static,
|
||||
{
|
||||
/// Builds a route stack that applies policy filters to requests and
|
||||
/// distributes requests over each route's backends. These [`Concrete`]
|
||||
/// backends are expected to be cached/shared by the inner stack.
|
||||
pub(crate) fn layer<N, I, NSvc>(
|
||||
) -> impl svc::Layer<N, Service = svc::ArcNewCloneTcp<Self, I>> + Clone
|
||||
where
|
||||
I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static,
|
||||
// Inner stack.
|
||||
N: svc::NewService<Concrete<T>, Service = NSvc> + Clone + Send + Sync + 'static,
|
||||
NSvc: svc::Service<I, Response = ()> + Clone + Send + Sync + 'static,
|
||||
NSvc::Future: Send,
|
||||
NSvc::Error: Into<Error>,
|
||||
{
|
||||
svc::layer::mk(move |inner| {
|
||||
svc::stack(inner)
|
||||
.push_map_target(|t| t)
|
||||
.push_map_target(|b: Backend<T>| b.concrete)
|
||||
.lift_new()
|
||||
.push(NewDistribute::layer())
|
||||
// The router does not take the backend's availability into
|
||||
// consideration, so we must eagerly fail requests to prevent
|
||||
// leaking tasks onto the runtime.
|
||||
.push_on_service(svc::LoadShed::layer())
|
||||
.push(svc::NewMapErr::layer_with(|rt: &Self| {
|
||||
let route = rt.params.route_ref.clone();
|
||||
move |source| RouteError {
|
||||
route: route.clone(),
|
||||
source,
|
||||
}
|
||||
}))
|
||||
.arc_new_clone_tcp()
|
||||
.into_inner()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> svc::Param<BackendDistribution<T>> for MatchedRoute<T> {
|
||||
fn param(&self) -> BackendDistribution<T> {
|
||||
self.params.distribution.clone()
|
||||
}
|
||||
}
|
||||
|
|
@@ -0,0 +1,215 @@
|
|||
use super::{
|
||||
super::{concrete, Concrete},
|
||||
route, LogicalAddr, NoRoute,
|
||||
};
|
||||
use crate::{BackendRef, EndpointRef, RouteRef};
|
||||
use linkerd_app_core::{
|
||||
io, proxy::http, svc, tls::ServerName, transport::addrs::*, Addr, Error, NameAddr, Result,
|
||||
};
|
||||
use linkerd_distribute as distribute;
|
||||
use linkerd_proxy_client_policy as policy;
|
||||
use linkerd_tls_route as tls_route;
|
||||
use std::{fmt::Debug, hash::Hash, sync::Arc};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub(crate) struct Router<T: Clone + Debug + Eq + Hash> {
|
||||
pub(super) parent: T,
|
||||
pub(super) addr: Addr,
|
||||
pub(super) routes: Arc<[tls_route::Route<route::Route<T>>]>,
|
||||
pub(super) backends: distribute::Backends<Concrete<T>>,
|
||||
}
|
||||
|
||||
type NewBackendCache<T, N, S> = distribute::NewBackendCache<Concrete<T>, (), N, S>;
|
||||
|
||||
// === impl Router ===
|
||||
|
||||
impl<T> Router<T>
|
||||
where
|
||||
// Parent target type.
|
||||
T: Eq + Hash + Clone + Debug + Send + Sync + 'static,
|
||||
T: svc::Param<ServerName>,
|
||||
{
|
||||
pub fn layer<N, I, NSvc>() -> impl svc::Layer<N, Service = svc::ArcNewCloneTcp<Self, I>> + Clone
|
||||
where
|
||||
I: io::AsyncRead + io::AsyncWrite + Debug + Send + Unpin + 'static,
|
||||
// Concrete stack.
|
||||
N: svc::NewService<Concrete<T>, Service = NSvc> + Clone + Send + Sync + 'static,
|
||||
NSvc: svc::Service<I, Response = ()> + Clone + Send + Sync + 'static,
|
||||
NSvc::Future: Send,
|
||||
NSvc::Error: Into<Error>,
|
||||
{
|
||||
svc::layer::mk(move |inner| {
|
||||
svc::stack(inner)
|
||||
.lift_new()
|
||||
// Each route builds over concrete backends. All of these
|
||||
// backends are cached here and shared across routes.
|
||||
.push(NewBackendCache::layer())
|
||||
.push_on_service(route::MatchedRoute::layer())
|
||||
.push(svc::NewOneshotRoute::<Self, (), _>::layer_cached())
|
||||
.arc_new_clone_tcp()
|
||||
.into_inner()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<(crate::tls::Routes, T)> for Router<T>
|
||||
where
|
||||
T: Eq + Hash + Clone + Debug,
|
||||
{
|
||||
fn from((rts, parent): (crate::tls::Routes, T)) -> Self {
|
||||
let crate::tls::Routes {
|
||||
addr,
|
||||
meta: parent_ref,
|
||||
routes,
|
||||
backends,
|
||||
} = rts;
|
||||
|
||||
let mk_concrete = {
|
||||
let parent = parent.clone();
|
||||
let parent_ref = parent_ref.clone();
|
||||
|
||||
move |backend_ref: BackendRef, target: concrete::Dispatch| Concrete {
|
||||
target,
|
||||
parent: parent.clone(),
|
||||
backend_ref,
|
||||
parent_ref: parent_ref.clone(),
|
||||
}
|
||||
};
|
||||
|
||||
let mk_dispatch = move |bke: &policy::Backend| match bke.dispatcher {
|
||||
policy::BackendDispatcher::BalanceP2c(
|
||||
policy::Load::PeakEwma(policy::PeakEwma { decay, default_rtt }),
|
||||
policy::EndpointDiscovery::DestinationGet { ref path },
|
||||
) => mk_concrete(
|
||||
BackendRef(bke.meta.clone()),
|
||||
concrete::Dispatch::Balance(
|
||||
path.parse::<NameAddr>()
|
||||
.expect("destination must be a nameaddr"),
|
||||
http::balance::EwmaConfig { decay, default_rtt },
|
||||
),
|
||||
),
|
||||
policy::BackendDispatcher::Forward(addr, ref md) => mk_concrete(
|
||||
EndpointRef::new(md, addr.port().try_into().expect("port must not be 0")).into(),
|
||||
concrete::Dispatch::Forward(Remote(ServerAddr(addr)), md.clone()),
|
||||
),
|
||||
policy::BackendDispatcher::Fail { ref message } => mk_concrete(
|
||||
BackendRef(policy::Meta::new_default("fail")),
|
||||
concrete::Dispatch::Fail {
|
||||
message: message.clone(),
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
let mk_route_backend =
|
||||
|route_ref: &RouteRef, rb: &policy::RouteBackend<policy::tls::Filter>| {
|
||||
let concrete = mk_dispatch(&rb.backend);
|
||||
route::Backend {
|
||||
route_ref: route_ref.clone(),
|
||||
concrete,
|
||||
}
|
||||
};
|
||||
|
||||
let mk_distribution =
|
||||
|rr: &RouteRef, d: &policy::RouteDistribution<policy::tls::Filter>| match d {
|
||||
policy::RouteDistribution::Empty => route::BackendDistribution::Empty,
|
||||
policy::RouteDistribution::FirstAvailable(backends) => {
|
||||
route::BackendDistribution::first_available(
|
||||
backends.iter().map(|b| mk_route_backend(rr, b)),
|
||||
)
|
||||
}
|
||||
policy::RouteDistribution::RandomAvailable(backends) => {
|
||||
route::BackendDistribution::random_available(
|
||||
backends
|
||||
.iter()
|
||||
.map(|(rb, weight)| (mk_route_backend(rr, rb), *weight)),
|
||||
)
|
||||
.expect("distribution must be valid")
|
||||
}
|
||||
};
|
||||
|
||||
let mk_policy =
|
||||
|policy::RoutePolicy::<policy::tls::Filter, ()> {
|
||||
meta, distribution, ..
|
||||
}| {
|
||||
let route_ref = RouteRef(meta);
|
||||
let parent_ref = parent_ref.clone();
|
||||
|
||||
let distribution = mk_distribution(&route_ref, &distribution);
|
||||
route::Route {
|
||||
addr: addr.clone(),
|
||||
parent: parent.clone(),
|
||||
parent_ref: parent_ref.clone(),
|
||||
route_ref,
|
||||
distribution,
|
||||
}
|
||||
};
|
||||
|
||||
let routes = routes
|
||||
.iter()
|
||||
.map(|route| tls_route::Route {
|
||||
snis: route.snis.clone(),
|
||||
rules: route
|
||||
.rules
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|tls_route::Rule { matches, policy }| tls_route::Rule {
|
||||
matches,
|
||||
policy: mk_policy(policy),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let backends = backends.iter().map(mk_dispatch).collect();
|
||||
|
||||
Self {
|
||||
routes,
|
||||
backends,
|
||||
addr,
|
||||
parent,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, I> svc::router::SelectRoute<I> for Router<T>
|
||||
where
|
||||
T: Clone + Eq + Hash + Debug,
|
||||
T: svc::Param<ServerName>,
|
||||
{
|
||||
type Key = route::MatchedRoute<T>;
|
||||
type Error = NoRoute;
|
||||
|
||||
fn select(&self, _: &I) -> Result<Self::Key, Self::Error> {
|
||||
use linkerd_tls_route::SessionInfo;
|
||||
|
||||
let server_name: ServerName = self.parent.param();
|
||||
tracing::trace!("Selecting TLS route for {:?}", server_name);
|
||||
let si = SessionInfo { sni: server_name };
|
||||
let (r#match, params) = policy::tls::find(&self.routes, si).ok_or(NoRoute)?;
|
||||
tracing::debug!(meta = ?params.route_ref, "Selected route");
|
||||
tracing::trace!(?r#match);
|
||||
|
||||
Ok(route::MatchedRoute {
|
||||
r#match,
|
||||
params: params.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<LogicalAddr> for Router<T>
|
||||
where
|
||||
T: Eq + Hash + Clone + Debug,
|
||||
{
|
||||
fn param(&self) -> LogicalAddr {
|
||||
LogicalAddr(self.addr.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> svc::Param<distribute::Backends<Concrete<T>>> for Router<T>
|
||||
where
|
||||
T: Eq + Hash + Clone + Debug,
|
||||
{
|
||||
fn param(&self) -> distribute::Backends<Concrete<T>> {
|
||||
self.backends.clone()
|
||||
}
|
||||
}
|
||||
|
|
@@ -0,0 +1,213 @@
|
|||
use super::{Outbound, ParentRef, Routes};
|
||||
use crate::test_util::*;
|
||||
use linkerd_app_core::{
|
||||
io,
|
||||
svc::{self, NewService},
|
||||
transport::addrs::*,
|
||||
Result,
|
||||
};
|
||||
use linkerd_app_test::{AsyncReadExt, AsyncWriteExt};
|
||||
use linkerd_proxy_client_policy::{self as client_policy, tls::sni};
|
||||
use parking_lot::Mutex;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::sync::watch;
|
||||
|
||||
mod basic;
|
||||
|
||||
const REQUEST: &[u8] = b"who r u?";
|
||||
type Reponse = tokio::task::JoinHandle<io::Result<String>>;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Target {
|
||||
num: usize,
|
||||
routes: watch::Receiver<Routes>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
||||
struct MockServer {
|
||||
io: support::io::Builder,
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct ConnectTcp {
|
||||
srvs: Arc<Mutex<HashMap<SocketAddr, MockServer>>>,
|
||||
}
|
||||
|
||||
// === impl MockServer ===
|
||||
|
||||
impl MockServer {
|
||||
fn new(
|
||||
addr: SocketAddr,
|
||||
service_name: &str,
|
||||
client_hello: Vec<u8>,
|
||||
) -> (Self, io::DuplexStream, Reponse) {
|
||||
let mut io = support::io();
|
||||
|
||||
io.write(&client_hello)
|
||||
.write(REQUEST)
|
||||
.read(service_name.as_bytes());
|
||||
|
||||
let server = MockServer { io, addr };
|
||||
let (io, response) = spawn_io(client_hello);
|
||||
|
||||
(server, io, response)
|
||||
}
|
||||
}
|
||||
|
||||
// === impl Target ===
|
||||
|
||||
impl PartialEq for Target {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.num == other.num
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Target {}
|
||||
|
||||
impl std::hash::Hash for Target {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.num.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl svc::Param<watch::Receiver<Routes>> for Target {
|
||||
fn param(&self) -> watch::Receiver<Routes> {
|
||||
self.routes.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// === impl ConnectTcp ===
|
||||
|
||||
impl ConnectTcp {
|
||||
fn add_server(&mut self, s: MockServer) {
|
||||
self.srvs.lock().insert(s.addr, s);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: svc::Param<Remote<ServerAddr>>> svc::Service<T> for ConnectTcp {
|
||||
type Response = (support::io::Mock, Local<ClientAddr>);
|
||||
type Error = io::Error;
|
||||
type Future = future::Ready<io::Result<Self::Response>>;
|
||||
|
||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, t: T) -> Self::Future {
|
||||
let Remote(ServerAddr(addr)) = t.param();
|
||||
let mut mock = self
|
||||
.srvs
|
||||
.lock()
|
||||
.remove(&addr)
|
||||
.expect("tried to connect to an unexpected address");
|
||||
|
||||
assert_eq!(addr, mock.addr);
|
||||
let local = Local(ClientAddr(addr));
|
||||
future::ok::<_, support::io::Error>((mock.io.build(), local))
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_io(
|
||||
client_hello: Vec<u8>,
|
||||
) -> (
|
||||
io::DuplexStream,
|
||||
tokio::task::JoinHandle<io::Result<String>>,
|
||||
) {
|
||||
let (mut client_io, server_io) = io::duplex(100);
|
||||
let task = tokio::spawn(async move {
|
||||
client_io.write_all(&client_hello).await?;
|
||||
client_io.write_all(REQUEST).await?;
|
||||
|
||||
let mut buf = String::with_capacity(100);
|
||||
client_io.read_to_string(&mut buf).await?;
|
||||
Ok(buf)
|
||||
});
|
||||
(server_io, task)
|
||||
}
|
||||
|
||||
fn default_backend(addr: SocketAddr) -> client_policy::Backend {
|
||||
use client_policy::{Backend, BackendDispatcher, EndpointMetadata, Meta, Queue};
|
||||
Backend {
|
||||
meta: Meta::new_default("test"),
|
||||
queue: Queue {
|
||||
capacity: 100,
|
||||
failfast_timeout: Duration::from_secs(10),
|
||||
},
|
||||
dispatcher: BackendDispatcher::Forward(addr, EndpointMetadata::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn sni_route(backend: client_policy::Backend, sni: sni::MatchSni) -> client_policy::tls::Route {
|
||||
use client_policy::{
|
||||
tls::{Filter, Policy, Route, Rule},
|
||||
Meta, RouteBackend, RouteDistribution,
|
||||
};
|
||||
use linkerd_tls_route::r#match::MatchSession;
|
||||
use once_cell::sync::Lazy;
|
||||
static NO_FILTERS: Lazy<Arc<[Filter]>> = Lazy::new(|| Arc::new([]));
|
||||
Route {
|
||||
snis: vec![sni],
|
||||
rules: vec![Rule {
|
||||
matches: vec![MatchSession::default()],
|
||||
policy: Policy {
|
||||
meta: Meta::new_default("test_route"),
|
||||
filters: NO_FILTERS.clone(),
|
||||
params: (),
|
||||
distribution: RouteDistribution::FirstAvailable(Arc::new([RouteBackend {
|
||||
filters: NO_FILTERS.clone(),
|
||||
backend,
|
||||
}])),
|
||||
},
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
// generates a sample ClientHello TLS message for testing
|
||||
fn generate_client_hello(sni: &str) -> Vec<u8> {
|
||||
use tokio_rustls::rustls::{
|
||||
internal::msgs::{
|
||||
base::Payload,
|
||||
enums::Compression,
|
||||
handshake::{
|
||||
ClientExtension, ClientHelloPayload, HandshakeMessagePayload, HandshakePayload,
|
||||
Random, SessionId,
|
||||
},
|
||||
message::{MessagePayload, PlainMessage},
|
||||
},
|
||||
server::DnsName,
|
||||
CipherSuite, ContentType, HandshakeType, ProtocolVersion,
|
||||
};
|
||||
|
||||
let sni = DnsName::try_from(sni.to_string()).unwrap();
|
||||
|
||||
let hs_payload = HandshakeMessagePayload {
|
||||
typ: HandshakeType::ClientHello,
|
||||
payload: HandshakePayload::ClientHello(ClientHelloPayload {
|
||||
client_version: ProtocolVersion::TLSv1_2,
|
||||
random: Random::from([0; 32]),
|
||||
session_id: SessionId::empty(),
|
||||
cipher_suites: vec![CipherSuite::TLS_NULL_WITH_NULL_NULL],
|
||||
compression_methods: vec![Compression::Null],
|
||||
extensions: vec![ClientExtension::make_sni(sni.borrow())],
|
||||
}),
|
||||
};
|
||||
|
||||
let mut hs_payload_bytes = Vec::default();
|
||||
MessagePayload::handshake(hs_payload).encode(&mut hs_payload_bytes);
|
||||
|
||||
let message = PlainMessage {
|
||||
typ: ContentType::Handshake,
|
||||
version: ProtocolVersion::TLSv1_2,
|
||||
payload: Payload(hs_payload_bytes),
|
||||
};
|
||||
|
||||
message.into_unencrypted_opaque().encode()
|
||||
}
|
||||
|
|
@@ -0,0 +1,73 @@
|
|||
use super::*;
|
||||
use crate::tls::Tls;
|
||||
use linkerd_app_core::{
|
||||
svc::ServiceExt,
|
||||
tls::{NewDetectRequiredSni, ServerName},
|
||||
trace, NameAddr,
|
||||
};
|
||||
use linkerd_proxy_client_policy as client_policy;
|
||||
use std::{net::SocketAddr, str::FromStr, sync::Arc};
|
||||
use tokio::sync::watch;
|
||||
|
||||
#[tokio::test(flavor = "current_thread", start_paused = true)]
|
||||
async fn routes() {
|
||||
let _trace = trace::test::trace_init();
|
||||
|
||||
const AUTHORITY: &str = "logical.test.svc.cluster.local";
|
||||
const PORT: u16 = 666;
|
||||
let addr = SocketAddr::new([192, 0, 2, 41].into(), PORT);
|
||||
let dest: NameAddr = format!("{AUTHORITY}:{PORT}")
|
||||
.parse::<NameAddr>()
|
||||
.expect("dest addr is valid");
|
||||
let resolve = support::resolver().endpoint_exists(dest.clone(), addr, Default::default());
|
||||
let (rt, _shutdown) = runtime();
|
||||
|
||||
let client_hello = generate_client_hello(AUTHORITY);
|
||||
let (srv, io, rsp) = MockServer::new(addr, AUTHORITY, client_hello);
|
||||
|
||||
let mut connect = ConnectTcp::default();
|
||||
connect.add_server(srv);
|
||||
|
||||
let stack = Outbound::new(default_config(), rt, &mut Default::default())
|
||||
.with_stack(connect)
|
||||
.push_tls_concrete(resolve)
|
||||
.push_tls_logical()
|
||||
.map_stack(|config, _rt, stk| {
|
||||
stk.push_new_idle_cached(config.discovery_idle_timeout)
|
||||
.push_map_target(|(sni, parent): (ServerName, _)| Tls { sni, parent })
|
||||
.push(NewDetectRequiredSni::layer(Duration::from_secs(1)))
|
||||
.arc_new_clone_tcp()
|
||||
})
|
||||
.into_inner();
|
||||
|
||||
let correct_backend = default_backend(addr);
|
||||
let correct_route = sni_route(
|
||||
correct_backend.clone(),
|
||||
sni::MatchSni::Exact(AUTHORITY.into()),
|
||||
);
|
||||
|
||||
let wrong_addr = SocketAddr::new([0, 0, 0, 0].into(), PORT);
|
||||
let wrong_backend = default_backend(wrong_addr);
|
||||
let wrong_route_1 = sni_route(
|
||||
wrong_backend.clone(),
|
||||
sni::MatchSni::from_str("foo").unwrap(),
|
||||
);
|
||||
let wrong_route_2 = sni_route(
|
||||
wrong_backend.clone(),
|
||||
sni::MatchSni::from_str("*.test.svc.cluster.local").unwrap(),
|
||||
);
|
||||
|
||||
let (_route_tx, routes) = watch::channel(Routes {
|
||||
addr: addr.into(),
|
||||
backends: Arc::new([correct_backend, wrong_backend]),
|
||||
routes: Arc::new([correct_route, wrong_route_1, wrong_route_2]),
|
||||
meta: ParentRef(client_policy::Meta::new_default("parent")),
|
||||
});
|
||||
|
||||
let target = Target { num: 1, routes };
|
||||
let svc = stack.new_service(target);
|
||||
|
||||
svc.oneshot(io).await.unwrap();
|
||||
let msg = rsp.await.unwrap().unwrap();
|
||||
assert_eq!(msg, AUTHORITY);
|
||||
}
|
||||
|
|
@@ -1,4 +1,4 @@
|
|||
use crate::{IoSlice, PeerAddr, Poll};
|
||||
use crate::{IoSlice, Peek, PeerAddr, Poll};
|
||||
use futures::ready;
|
||||
use linkerd_errno::Errno;
|
||||
use pin_project::pin_project;
|
||||
|
|
@@ -82,3 +82,10 @@ impl<T: PeerAddr, S> PeerAddr for SensorIo<T, S> {
|
|||
self.io.peer_addr()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<I: Peek + Send + Sync, S: Sensor + Sync> Peek for SensorIo<I, S> {
|
||||
async fn peek(&self, buf: &mut [u8]) -> Result<usize> {
|
||||
self.io.peek(buf).await
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -1,107 +0,0 @@
|
|||
use crate::{
|
||||
server::{detect_sni, DetectIo, Timeout},
|
||||
ServerName,
|
||||
};
|
||||
use linkerd_error::Error;
|
||||
use linkerd_io as io;
|
||||
use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Service, ServiceExt};
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::time;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Clone, Debug, Error)]
|
||||
#[error("SNI detection timed out")]
|
||||
pub struct SniDetectionTimeoutError;
|
||||
|
||||
#[derive(Clone, Debug, Error)]
|
||||
#[error("Could not find SNI")]
|
||||
pub struct NoSniFoundError;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NewDetectSni<P, N> {
|
||||
params: P,
|
||||
inner: N,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DetectSni<T, P, N> {
|
||||
target: T,
|
||||
inner: N,
|
||||
timeout: Timeout,
|
||||
params: P,
|
||||
}
|
||||
|
||||
impl<P, N> NewDetectSni<P, N> {
|
||||
pub fn new(params: P, inner: N) -> Self {
|
||||
Self { inner, params }
|
||||
}
|
||||
|
||||
pub fn layer(params: P) -> impl layer::Layer<N, Service = Self> + Clone
|
||||
where
|
||||
P: Clone,
|
||||
{
|
||||
layer::mk(move |inner| Self::new(params.clone(), inner))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, P, N> NewService<T> for NewDetectSni<P, N>
|
||||
where
|
||||
P: ExtractParam<Timeout, T> + Clone,
|
||||
N: Clone,
|
||||
{
|
||||
type Service = DetectSni<T, P, N>;
|
||||
|
||||
fn new_service(&self, target: T) -> Self::Service {
|
||||
let timeout = self.params.extract_param(&target);
|
||||
DetectSni {
|
||||
target,
|
||||
timeout,
|
||||
inner: self.inner.clone(),
|
||||
params: self.params.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, P, I, N, S> Service<I> for DetectSni<T, P, N>
|
||||
where
|
||||
T: Clone + Send + Sync + 'static,
|
||||
P: InsertParam<ServerName, T> + Clone + Send + Sync + 'static,
|
||||
P::Target: Send + 'static,
|
||||
I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static,
|
||||
N: NewService<P::Target, Service = S> + Clone + Send + 'static,
|
||||
S: Service<DetectIo<I>> + Send,
|
||||
S::Error: Into<Error>,
|
||||
S::Future: Send,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<S::Response, Error>> + Send + 'static>>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, io: I) -> Self::Future {
|
||||
let target = self.target.clone();
|
||||
let new_accept = self.inner.clone();
|
||||
let params = self.params.clone();
|
||||
|
||||
// Detect the SNI from a ClientHello (or timeout).
|
||||
let Timeout(timeout) = self.timeout;
|
||||
let detect = time::timeout(timeout, detect_sni(io));
|
||||
Box::pin(async move {
|
||||
let (sni, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??;
|
||||
let sni = sni.ok_or(NoSniFoundError)?;
|
||||
|
||||
debug!("detected SNI: {:?}", sni);
|
||||
let svc = new_accept.new_service(params.insert_param(sni, target));
|
||||
svc.oneshot(io).await.map_err(Into::into)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@@ -2,12 +2,14 @@
|
|||
#![forbid(unsafe_code)]
|
||||
|
||||
pub mod client;
|
||||
pub mod detect_sni;
|
||||
pub mod server;
|
||||
|
||||
pub use self::{
|
||||
client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId},
|
||||
server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls},
|
||||
server::{
|
||||
ClientId, ConditionalServerTls, NewDetectRequiredSni, NewDetectTls, NoServerTls,
|
||||
NoSniFoundError, ServerTls, SniDetectionTimeoutError,
|
||||
},
|
||||
};
|
||||
|
||||
use linkerd_dns_name as dns;
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,5 @@
|
|||
mod client_hello;
|
||||
mod required_sni;
|
||||
|
||||
use crate::{NegotiatedProtocol, ServerName};
|
||||
use bytes::BytesMut;
|
||||
|
|
@@ -18,6 +19,8 @@ use thiserror::Error;
|
|||
use tokio::time::{self, Duration};
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
pub use self::required_sni::{NewDetectRequiredSni, NoSniFoundError, SniDetectionTimeoutError};
|
||||
|
||||
/// Describes the authenticated identity of a remote client.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
|
||||
pub struct ClientId(pub id::Id);
|
||||
|
|
@@ -65,6 +68,7 @@ pub struct NewDetectTls<L, P, N> {
|
|||
_local_identity: std::marker::PhantomData<fn() -> L>,
|
||||
}
|
||||
|
||||
/// A param type used to indicate the timeout after which detection should fail.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Timeout(pub Duration);
|
||||
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,118 @@
|
|||
use crate::{
|
||||
server::{detect_sni, DetectIo},
|
||||
ServerName,
|
||||
};
|
||||
use linkerd_error::Error;
|
||||
use linkerd_io as io;
|
||||
use linkerd_stack::{layer, NewService, Service, ServiceExt};
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::time;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Clone, Debug, Error)]
|
||||
#[error("SNI detection timed out")]
|
||||
pub struct SniDetectionTimeoutError;
|
||||
|
||||
#[derive(Clone, Debug, Error)]
|
||||
#[error("Could not find SNI")]
|
||||
pub struct NoSniFoundError;
|
||||
|
||||
/// A NewService that instruments an inner stack with knowledge of the
|
||||
/// connection's TLS ServerName (i.e. from an SNI header).
|
||||
///
|
||||
/// This differs from the parent module's NewDetectTls in a few ways:
|
||||
///
|
||||
/// - It requires that all connections have an SNI.
|
||||
/// - It assumes that these connections may not be terminated locally, so there
|
||||
/// is no concept of a local server name.
|
||||
/// - There are no special affordances for mutually authenticated TLS, so we
|
||||
/// make no attempt to detect the client's identity.
|
||||
/// - The detection timeout is fixed and cannot vary per target (for
|
||||
/// convenience, to reduce needless boilerplate).
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NewDetectRequiredSni<N> {
|
||||
inner: N,
|
||||
timeout: time::Duration,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DetectRequiredSni<T, N> {
|
||||
target: T,
|
||||
inner: N,
|
||||
timeout: time::Duration,
|
||||
}
|
||||
|
||||
// === impl NewDetectRequiredSni ===
|
||||
|
||||
impl<N> NewDetectRequiredSni<N> {
|
||||
fn new(timeout: time::Duration, inner: N) -> Self {
|
||||
Self { inner, timeout }
|
||||
}
|
||||
|
||||
pub fn layer(timeout: time::Duration) -> impl layer::Layer<N, Service = Self> + Clone {
|
||||
layer::mk(move |inner| Self::new(timeout, inner))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, N> NewService<T> for NewDetectRequiredSni<N>
|
||||
where
|
||||
N: Clone,
|
||||
{
|
||||
type Service = DetectRequiredSni<T, N>;
|
||||
|
||||
fn new_service(&self, target: T) -> Self::Service {
|
||||
DetectRequiredSni::new(self.timeout, target, self.inner.clone())
|
||||
}
|
||||
}
|
||||
|
||||
// === impl DetectRequiredSni ===
|
||||
|
||||
impl<T, N> DetectRequiredSni<T, N> {
|
||||
fn new(timeout: time::Duration, target: T, inner: N) -> Self {
|
||||
Self {
|
||||
target,
|
||||
inner,
|
||||
timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, I, N, S> Service<I> for DetectRequiredSni<T, N>
|
||||
where
|
||||
T: Clone + Send + Sync + 'static,
|
||||
I: io::AsyncRead + io::Peek + io::AsyncWrite + Send + Sync + Unpin + 'static,
|
||||
N: NewService<(ServerName, T), Service = S> + Clone + Send + 'static,
|
||||
S: Service<DetectIo<I>> + Send,
|
||||
S::Error: Into<Error>,
|
||||
S::Future: Send,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<S::Response, Error>> + Send + 'static>>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, io: I) -> Self::Future {
|
||||
let target = self.target.clone();
|
||||
let new_accept = self.inner.clone();
|
||||
|
||||
// Detect the SNI from a ClientHello (or timeout).
|
||||
let detect = time::timeout(self.timeout, detect_sni(io));
|
||||
Box::pin(async move {
|
||||
let (res, io) = detect.await.map_err(|_| SniDetectionTimeoutError)??;
|
||||
let sni = res.ok_or(NoSniFoundError)?;
|
||||
debug!(?sni, "Detected TLS");
|
||||
|
||||
let svc = new_accept.new_service((sni, target));
|
||||
svc.oneshot(io).await.map_err(Into::into)
|
||||
})
|
||||
}
|
||||
}
|