From 74b4935271d84c17f285feb8ee2ca712ca4c2ca4 Mon Sep 17 00:00:00 2001 From: katelyn martin Date: Fri, 14 Feb 2025 00:00:00 +0000 Subject: [PATCH] wip: app/core, io, meshtls, proxy/transport, tls --- Cargo.lock | 11 +++ linkerd/app/core/Cargo.toml | 1 + linkerd/app/core/src/control.rs | 12 ++- linkerd/app/core/src/errors/body.rs | 27 ++---- linkerd/io/Cargo.toml | 1 + linkerd/io/src/either.rs | 54 +++++++++++ linkerd/io/src/prefixed.rs | 55 +++++++++++ linkerd/io/src/scoped.rs | 46 +++++++++ linkerd/io/src/sensor.rs | 36 +++++++ linkerd/meshtls/Cargo.toml | 3 + linkerd/meshtls/boring/Cargo.toml | 2 + linkerd/meshtls/boring/src/client.rs | 27 ++++-- linkerd/meshtls/rustls/Cargo.toml | 2 + linkerd/meshtls/rustls/src/client.rs | 32 +++++-- linkerd/meshtls/src/client.rs | 71 ++++++++++++++ linkerd/proxy/transport/Cargo.toml | 2 + linkerd/proxy/transport/src/connect.rs | 126 ++++++++++++++++++++++++- linkerd/proxy/transport/src/lib.rs | 8 +- linkerd/tls/Cargo.toml | 1 + 19 files changed, 467 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 042505b7d..4d0eadbd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1407,6 +1407,7 @@ dependencies = [ "http 1.2.0", "http-body", "hyper", + "hyper-util", "ipnet", "linkerd-addr", "linkerd-conditional", @@ -1994,6 +1995,7 @@ dependencies = [ "async-trait", "bytes", "futures", + "hyper", "hyper-util", "linkerd-errno", "pin-project", @@ -2007,6 +2009,8 @@ name = "linkerd-meshtls" version = "0.1.0" dependencies = [ "futures", + "hyper", + "hyper-util", "linkerd-conditional", "linkerd-dns-name", "linkerd-error", @@ -2032,6 +2036,8 @@ dependencies = [ "boring", "futures", "hex", + "hyper", + "hyper-util", "linkerd-dns-name", "linkerd-error", "linkerd-identity", @@ -2051,6 +2057,8 @@ name = "linkerd-meshtls-rustls" version = "0.1.0" dependencies = [ "futures", + "hyper", + "hyper-util", "linkerd-dns-name", "linkerd-error", "linkerd-identity", @@ -2485,6 +2493,8 @@ name = "linkerd-proxy-transport" version = "0.1.0" dependencies = [ "futures", + "hyper", + "hyper-util", "libc", "linkerd-error", "linkerd-io", @@ -2632,6 +2642,7 @@ dependencies = [ "async-trait", "bytes", "futures", + "hyper", "linkerd-conditional", "linkerd-dns-name", "linkerd-error", diff --git a/linkerd/app/core/Cargo.toml b/linkerd/app/core/Cargo.toml index 7c343a241..111457954 100644 --- a/linkerd/app/core/Cargo.toml +++ b/linkerd/app/core/Cargo.toml @@ -18,6 +18,7 @@ drain = { version = "0.1", features = ["retain"] } http = { workspace = true } http-body = { workspace = true } hyper = { workspace = true, features = ["http1", "http2"] } +hyper-util = { workspace = true } futures = { version = "0.3", default-features = false } ipnet = "2.11" prometheus-client = "0.22" diff --git a/linkerd/app/core/src/control.rs b/linkerd/app/core/src/control.rs index 92685c852..4fcb1133b 100644 --- a/linkerd/app/core/src/control.rs +++ b/linkerd/app/core/src/control.rs @@ -69,8 +69,10 @@ impl fmt::Display for ControlAddr { } } -pub type RspBody = - linkerd_http_metrics::requests::ResponseBody, classify::Eos>; +pub type RspBody = linkerd_http_metrics::requests::ResponseBody< + http::balance::Body, + classify::Eos, +>; #[derive(Clone, Debug, Default)] pub struct Metrics { @@ -112,7 +114,7 @@ impl Config { warn!(error, "Failed to resolve control-plane component"); if let Some(e) = crate::errors::cause_ref::(&*error) { if let Some(ttl) = e.negative_ttl() { - return Ok(Either::Left( + return Ok::<_, Error>(Either::Left( IntervalStream::new(time::interval(ttl)).map(|_| ()), )); } @@ -129,9 +131,9 @@ impl Config { self.connect.user_timeout, )) .push(tls::Client::layer(identity)) - .push_connect_timeout(self.connect.timeout) + .push_connect_timeout(self.connect.timeout) // Client .push_map_target(|(_version, target)| target) - .push(self::client::layer(self.connect.http2)) + .push(self::client::layer::<_, _>(self.connect.http2)) .push_on_service(svc::MapErr::layer_boxed()) .into_new_service(); diff --git a/linkerd/app/core/src/errors/body.rs b/linkerd/app/core/src/errors/body.rs index 975979838..10f23b07d 100644 --- a/linkerd/app/core/src/errors/body.rs +++ b/linkerd/app/core/src/errors/body.rs @@ -3,6 +3,7 @@ use super::{ respond::{HttpRescue, SyntheticHttpResponse}, }; use http::{header::HeaderValue, HeaderMap}; +use http_body::Frame; use linkerd_error::{Error, Result}; use pin_project::pin_project; use std::{ @@ -66,19 +67,18 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let ResponseBodyProj(inner) = self.as_mut().project(); match inner.project() { - InnerProj::Passthru(inner) => inner.poll_data(cx), - InnerProj::Rescued { trailers: _ } => Poll::Ready(None), + InnerProj::Passthru(inner) => inner.poll_frame(cx), InnerProj::GrpcRescue { inner, rescue, emit_headers, - } => match inner.poll_data(cx) { + } => match inner.poll_frame(cx) { Poll::Ready(Some(Err(error))) => { // The inner body has yielded an error, which we will try to rescue. If so, // store our synthetic trailers reporting the error. @@ -88,19 +88,10 @@ where } data => data, }, - } - } - - #[inline] - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let ResponseBodyProj(inner) = self.project(); - match inner.project() { - InnerProj::Passthru(inner) => inner.poll_trailers(cx), - InnerProj::GrpcRescue { inner, .. } => inner.poll_trailers(cx), - InnerProj::Rescued { trailers } => Poll::Ready(Ok(trailers.take())), + InnerProj::Rescued { trailers } => { + let trailers = trailers.take().map(Frame::trailers).map(Ok); + Poll::Ready(trailers) + } } } diff --git a/linkerd/io/Cargo.toml b/linkerd/io/Cargo.toml index 41f13bab3..aabddb14b 100644 --- a/linkerd/io/Cargo.toml +++ b/linkerd/io/Cargo.toml @@ -16,6 +16,7 @@ default = [] async-trait = "0.1" futures = { version = "0.3", default-features = false } bytes = { workspace = true } +hyper = { workspace = true, default-features = false } hyper-util = { workspace = true, features = ["tokio"] } linkerd-errno = { path = "../errno" } tokio = { version = "1", features = ["io-util", "net"] } diff --git a/linkerd/io/src/either.rs b/linkerd/io/src/either.rs index 142822320..0f930f8b3 100644 --- a/linkerd/io/src/either.rs +++ b/linkerd/io/src/either.rs @@ -47,6 +47,19 @@ impl io::AsyncRead for EitherIo { } } +impl hyper::rt::Read for EitherIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> io::Poll<()> { + match self.project() { + EitherIoProj::Left(l) => l.poll_read(cx, buf), + EitherIoProj::Right(r) => r.poll_read(cx, buf), + } + } +} + impl io::AsyncWrite for EitherIo { #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { @@ -92,3 +105,44 @@ impl io::AsyncWrite for EitherIo { } } } + +impl hyper::rt::Write for EitherIo { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { + match self.project() { + EitherIoProj::Left(l) => l.poll_write(cx, buf), + EitherIoProj::Right(r) => r.poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + match self.project() { + EitherIoProj::Left(l) => l.poll_flush(cx), + EitherIoProj::Right(r) => r.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + match self.project() { + EitherIoProj::Left(l) => l.poll_shutdown(cx), + EitherIoProj::Right(r) => r.poll_shutdown(cx), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + EitherIo::Left(l) => l.is_write_vectored(), + EitherIo::Right(r) => r.is_write_vectored(), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> io::Poll { + match self.project() { + EitherIoProj::Left(l) => l.poll_write_vectored(cx, bufs), + EitherIoProj::Right(r) => r.poll_write_vectored(cx, bufs), + } + } +} diff --git a/linkerd/io/src/prefixed.rs b/linkerd/io/src/prefixed.rs index 27237cac1..1cee707d2 100644 --- a/linkerd/io/src/prefixed.rs +++ b/linkerd/io/src/prefixed.rs @@ -78,6 +78,35 @@ impl io::AsyncRead for PrefixedIo { } } +impl hyper::rt::Read for PrefixedIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> io::Poll<()> { + // XXX(kate): this is copy-pasted from `io::AsyncRead`, above. + let this = self.project(); + // Check the length only once, since looking as the length + // of a Bytes isn't as cheap as the length of a &[u8]. + let peeked_len = this.prefix.len(); + + if peeked_len == 0 { + this.io.poll_read(cx, buf) + } else { + let len = cmp::min(buf.remaining(), peeked_len); + buf.put_slice(&this.prefix.as_ref()[..len]); + this.prefix.advance(len); + // If we've finally emptied the prefix, drop it so we don't + // hold onto the allocated memory any longer. We won't peek + // again. + if peeked_len == len { + *this.prefix = Bytes::new(); + } + io::Poll::Ready(Ok(())) + } + } +} + impl io::Write for PrefixedIo { #[inline] fn write(&mut self, buf: &[u8]) -> io::Result { @@ -120,3 +149,29 @@ impl io::AsyncWrite for PrefixedIo { self.io.is_write_vectored() } } + +impl hyper::rt::Write for PrefixedIo { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { + self.project().io.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + self.project().io.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + self.project().io.poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> io::Poll { + self.project().io.poll_write_vectored(cx, bufs) + } +} diff --git a/linkerd/io/src/scoped.rs b/linkerd/io/src/scoped.rs index c52360a5f..f5ffa4af8 100644 --- a/linkerd/io/src/scoped.rs +++ b/linkerd/io/src/scoped.rs @@ -89,6 +89,17 @@ impl io::AsyncRead for ScopedIo { } } +impl hyper::rt::Read for ScopedIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> io::Poll<()> { + let this = self.project(); + this.io.poll_read(cx, buf).map_err(this.scope.err()) + } +} + impl io::Write for ScopedIo { #[inline] fn write(&mut self, buf: &[u8]) -> io::Result { @@ -138,3 +149,38 @@ impl io::AsyncWrite for ScopedIo { self.io.is_write_vectored() } } + +impl hyper::rt::Write for ScopedIo { + #[inline] + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { + let this = self.project(); + this.io.poll_write(cx, buf).map_err(this.scope.err()) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + let this = self.project(); + this.io.poll_flush(cx).map_err(this.scope.err()) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + let this = self.project(); + this.io.poll_shutdown(cx).map_err(this.scope.err()) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> io::Poll { + let this = self.project(); + this.io + .poll_write_vectored(cx, bufs) + .map_err(this.scope.err()) + } +} diff --git a/linkerd/io/src/sensor.rs b/linkerd/io/src/sensor.rs index 9175b2ab3..04fae1f15 100644 --- a/linkerd/io/src/sensor.rs +++ b/linkerd/io/src/sensor.rs @@ -77,6 +77,42 @@ impl AsyncWrite for SensorIo { } } +impl hyper::rt::Write for SensorIo { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.project(); + this.sensor.record_error(this.io.poll_shutdown(cx)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.project(); + this.sensor.record_error(this.io.poll_flush(cx)) + } + + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll { + let this = self.project(); + let bytes = ready!(this.sensor.record_error(this.io.poll_write(cx, buf)))?; + this.sensor.record_write(bytes); + Poll::Ready(Ok(bytes)) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll { + let this = self.project(); + let bytes = ready!(this + .sensor + .record_error(this.io.poll_write_vectored(cx, bufs)))?; + this.sensor.record_write(bytes); + Poll::Ready(Ok(bytes)) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } +} + impl PeerAddr for SensorIo { fn peer_addr(&self) -> Result { self.io.peer_addr() diff --git a/linkerd/meshtls/Cargo.toml b/linkerd/meshtls/Cargo.toml index 1373c0fc1..51017b533 100644 --- a/linkerd/meshtls/Cargo.toml +++ b/linkerd/meshtls/Cargo.toml @@ -15,7 +15,10 @@ __has_any_tls_impls = [] [dependencies] futures = { version = "0.3", default-features = false } +hyper = { workspace = true } +hyper-util = { workspace = true } pin-project = "1" +tokio = { version = "1", default-features = false } linkerd-dns-name = { path = "../dns/name" } linkerd-error = { path = "../error" } diff --git a/linkerd/meshtls/boring/Cargo.toml b/linkerd/meshtls/boring/Cargo.toml index 61618e91e..054e6b40b 100644 --- a/linkerd/meshtls/boring/Cargo.toml +++ b/linkerd/meshtls/boring/Cargo.toml @@ -10,6 +10,8 @@ publish = false boring = "4" futures = { version = "0.3", default-features = false } hex = "0.4" # used for debug logging +hyper = { workspace = true } +hyper-util = { workspace = true } linkerd-error = { path = "../../error" } linkerd-dns-name = { path = "../../dns/name" } linkerd-identity = { path = "../../identity" } diff --git a/linkerd/meshtls/boring/src/client.rs b/linkerd/meshtls/boring/src/client.rs index d4f1d16cf..66743b150 100644 --- a/linkerd/meshtls/boring/src/client.rs +++ b/linkerd/meshtls/boring/src/client.rs @@ -21,7 +21,7 @@ pub struct Connect { pub type ConnectFuture = Pin>> + Send>>; #[derive(Debug)] -pub struct ClientIo(tokio_boring::SslStream); +pub struct ClientIo(hyper_util::rt::TokioIo>); // === impl NewClient === @@ -117,7 +117,7 @@ where "Initiated TLS connection" ); trace!(peer.id = %server_id, peer.name = %server_name); - Ok(ClientIo(io)) + Ok(ClientIo(hyper_util::rt::TokioIo::new(io))) }) } } @@ -131,6 +131,16 @@ impl io::AsyncRead for ClientIo { cx: &mut Context<'_>, buf: &mut io::ReadBuf<'_>, ) -> io::Poll<()> { + Pin::new(self.0.inner_mut()).poll_read(cx, buf) + } +} + +impl hyper::rt::Read for ClientIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> std::task::Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } @@ -138,17 +148,17 @@ impl io::AsyncRead for ClientIo { impl io::AsyncWrite for ClientIo { #[inline] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - Pin::new(&mut self.0).poll_flush(cx) + Pin::new(self.0.inner_mut()).poll_flush(cx) } #[inline] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - Pin::new(&mut self.0).poll_shutdown(cx) + Pin::new(self.0.inner_mut()).poll_shutdown(cx) } #[inline] fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { - Pin::new(&mut self.0).poll_write(cx, buf) + Pin::new(self.0.inner_mut()).poll_write(cx, buf) } #[inline] @@ -157,12 +167,12 @@ impl io::AsyncWrite for ClientIo { cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> io::Poll { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + Pin::new(self.0.inner_mut()).poll_write_vectored(cx, bufs) } #[inline] fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() + self.0.inner().is_write_vectored() } } @@ -170,6 +180,7 @@ impl ClientIo { #[inline] pub fn negotiated_protocol(&self) -> Option> { self.0 + .inner() .ssl() .selected_alpn_protocol() .map(NegotiatedProtocolRef) @@ -179,6 +190,6 @@ impl ClientIo { impl io::PeerAddr for ClientIo { #[inline] fn peer_addr(&self) -> io::Result { - self.0.get_ref().peer_addr() + self.0.inner().get_ref().peer_addr() } } diff --git a/linkerd/meshtls/rustls/Cargo.toml b/linkerd/meshtls/rustls/Cargo.toml index b13ec1455..0be0ea7e4 100644 --- a/linkerd/meshtls/rustls/Cargo.toml +++ b/linkerd/meshtls/rustls/Cargo.toml @@ -11,6 +11,8 @@ test-util = ["linkerd-tls-test-util"] [dependencies] futures = { version = "0.3", default-features = false } +hyper = { workspace = true } +hyper-util = { workspace = true } ring = { version = "0.17", features = ["std"] } rustls-pemfile = "2.2" rustls-webpki = { version = "0.102.8", features = ["std"] } diff --git a/linkerd/meshtls/rustls/src/client.rs b/linkerd/meshtls/rustls/src/client.rs index 9856d3899..eef88b736 100644 --- a/linkerd/meshtls/rustls/src/client.rs +++ b/linkerd/meshtls/rustls/src/client.rs @@ -3,7 +3,10 @@ use linkerd_identity as id; use linkerd_io as io; use linkerd_meshtls_verifier as verifier; use linkerd_stack::{NewService, Service}; -use linkerd_tls::{client::AlpnProtocols, ClientTls, NegotiatedProtocolRef}; +use linkerd_tls::{ + client::{self, AlpnProtocols}, + ClientTls, NegotiatedProtocolRef, +}; use std::{convert::TryFrom, pin::Pin, sync::Arc, task::Context}; use tokio::sync::watch; use tokio_rustls::rustls::{self, pki_types::CertificateDer, ClientConfig}; @@ -25,7 +28,7 @@ pub struct Connect { pub type ConnectFuture = Pin>> + Send>>; #[derive(Debug)] -pub struct ClientIo(tokio_rustls::client::TlsStream); +pub struct ClientIo(hyper_util::rt::TokioIo>); // === impl NewClient === @@ -115,7 +118,7 @@ where let (_, conn) = s.get_ref(); let end_cert = extract_cert(conn)?; verifier::verify_id(end_cert, &server_id)?; - Ok(ClientIo(s)) + Ok(ClientIo(hyper_util::rt::TokioIo::new(s))) }), ) } @@ -130,6 +133,16 @@ impl io::AsyncRead for ClientIo { cx: &mut Context<'_>, buf: &mut io::ReadBuf<'_>, ) -> io::Poll<()> { + Pin::new(self.0.inner_mut()).poll_read(cx, buf) + } +} + +impl hyper::rt::Read for ClientIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> std::task::Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } @@ -137,17 +150,17 @@ impl io::AsyncRead for ClientIo { impl io::AsyncWrite for ClientIo { #[inline] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - Pin::new(&mut self.0).poll_flush(cx) + Pin::new(self.0.inner_mut()).poll_flush(cx) } #[inline] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - Pin::new(&mut self.0).poll_shutdown(cx) + Pin::new(self.0.inner_mut()).poll_shutdown(cx) } #[inline] fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { - Pin::new(&mut self.0).poll_write(cx, buf) + Pin::new(self.0.inner_mut()).poll_write(cx, buf) } #[inline] @@ -156,12 +169,12 @@ impl io::AsyncWrite for ClientIo { cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> io::Poll { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + Pin::new(self.0.inner_mut()).poll_write_vectored(cx, bufs) } #[inline] fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() + self.0.inner().is_write_vectored() } } @@ -169,6 +182,7 @@ impl ClientIo { #[inline] pub fn negotiated_protocol(&self) -> Option> { self.0 + .inner() .get_ref() .1 .alpn_protocol() @@ -179,6 +193,6 @@ impl ClientIo { impl io::PeerAddr for ClientIo { #[inline] fn peer_addr(&self) -> io::Result { - self.0.get_ref().0.peer_addr() + self.0.inner().get_ref().0.peer_addr() } } diff --git a/linkerd/meshtls/src/client.rs b/linkerd/meshtls/src/client.rs index 42a80d299..3c6d78cb5 100644 --- a/linkerd/meshtls/src/client.rs +++ b/linkerd/meshtls/src/client.rs @@ -180,6 +180,23 @@ impl io::AsyncRead for ClientIo { } } +impl hyper::rt::Read for ClientIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> io::Poll<()> { + match self.project() { + #[cfg(feature = "boring")] + ClientIoProj::Boring(io) => io.poll_read(cx, buf), + #[cfg(feature = "rustls")] + ClientIoProj::Rustls(io) => io.poll_read(cx, buf), + #[cfg(not(feature = "__has_any_tls_impls"))] + _ => crate::no_tls!(cx, buf), + } + } +} + impl io::AsyncWrite for ClientIo { #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { @@ -251,6 +268,60 @@ impl io::AsyncWrite for ClientIo { } } +impl hyper::rt::Write for ClientIo { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + #[cfg(feature = "boring")] + ClientIoProj::Boring(io) => tokio::io::AsyncWrite::poll_write(io, cx, buf), + #[cfg(feature = "rustls")] + ClientIoProj::Rustls(io) => tokio::io::AsyncWrite::poll_write(io, cx, buf), + #[cfg(not(feature = "__has_any_tls_impls"))] + _ => crate::no_tls!(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + #[cfg(feature = "boring")] + ClientIoProj::Boring(io) => tokio::io::AsyncWrite::poll_flush(io, cx), + #[cfg(feature = "rustls")] + ClientIoProj::Rustls(io) => tokio::io::AsyncWrite::poll_flush(io, cx), + #[cfg(not(feature = "__has_any_tls_impls"))] + _ => crate::no_tls!(cx), + } + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.project() { + #[cfg(feature = "boring")] + ClientIoProj::Boring(io) => tokio::io::AsyncWrite::poll_shutdown(io, cx), + #[cfg(feature = "rustls")] + ClientIoProj::Rustls(io) => tokio::io::AsyncWrite::poll_shutdown(io, cx), + #[cfg(not(feature = "__has_any_tls_impls"))] + _ => crate::no_tls!(cx), + } + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(self) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self, cx, bufs) + } +} + impl io::PeerAddr for ClientIo { #[inline] fn peer_addr(&self) -> io::Result { diff --git a/linkerd/proxy/transport/Cargo.toml b/linkerd/proxy/transport/Cargo.toml index 86260e07b..7628a1865 100644 --- a/linkerd/proxy/transport/Cargo.toml +++ b/linkerd/proxy/transport/Cargo.toml @@ -11,6 +11,8 @@ Transport-level implementations that rely on core proxy infrastructure [dependencies] futures = { version = "0.3", default-features = false } +hyper = { workspace = true } +hyper-util = { workspace = true } linkerd-error = { path = "../../error" } linkerd-io = { path = "../../io" } linkerd-stack = { path = "../../stack" } diff --git a/linkerd/proxy/transport/src/connect.rs b/linkerd/proxy/transport/src/connect.rs index d11b3753c..6a261284e 100644 --- a/linkerd/proxy/transport/src/connect.rs +++ b/linkerd/proxy/transport/src/connect.rs @@ -6,7 +6,6 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tokio::net::TcpStream; use tracing::debug; #[derive(Copy, Clone, Debug)] @@ -25,7 +24,7 @@ impl ConnectTcp { } impl>> Service for ConnectTcp { - type Response = (io::ScopedIo, Local); + type Response = (io::ScopedIo, Local); type Error = io::Error; type Future = Pin> + Send + Sync + 'static>>; @@ -39,7 +38,7 @@ impl>> Service for ConnectTcp { let Remote(ServerAddr(addr)) = t.param(); debug!(server.addr = %addr, "Connecting"); Box::pin(async move { - let io = TcpStream::connect(&addr).await?; + let io = tokio::net::TcpStream::connect(&addr).await?; super::set_nodelay_or_warn(&io); let io = super::set_keepalive_or_warn(io, keepalive)?; let io = super::set_user_timeout_or_warn(io, user_timeout)?; @@ -49,7 +48,126 @@ impl>> Service for ConnectTcp { ?keepalive, "Connected", ); - Ok((io::ScopedIo::client(io), Local(ClientAddr(local_addr)))) + Ok(( + io::ScopedIo::client(self::net::TcpStream(io)), + Local(ClientAddr(local_addr)), + )) }) } } + +mod net { + use super::*; + + /// A wrapper that implements Tokio's IO traits for an inner type that + /// implements hyper's IO traits, or vice versa (implements hyper's IO + /// traits for a type that implements Tokio's IO traits). + #[derive(Debug)] + pub struct TcpStream(pub tokio::net::TcpStream); + + impl TcpStream { + fn project(self: Pin<&mut Self>) -> Pin<&mut tokio::net::TcpStream> { + let Self(stream) = self.get_mut(); + Pin::new(stream) + } + } + + impl tokio::io::AsyncRead for TcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut linkerd_io::ReadBuf<'_>, + ) -> Poll> { + self.project().poll_read(cx, buf) + } + } + + impl tokio::io::AsyncWrite for TcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().poll_write(cx, buf) + } + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().poll_flush(cx) + } + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().poll_shutdown(cx) + } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.project().poll_write_vectored(cx, bufs) + } + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + } + + impl hyper::rt::Read for TcpStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> Poll> { + let n = unsafe { + let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); + match tokio::io::AsyncRead::poll_read(self.project(), cx, &mut tbuf) { + Poll::Ready(Ok(())) => tbuf.filled().len(), + other => return other, + } + }; + + unsafe { + buf.advance(n); + } + Poll::Ready(Ok(())) + } + } + + impl hyper::rt::Write for TcpStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write(self.project(), cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_flush(self.project(), cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + tokio::io::AsyncWrite::poll_shutdown(self.project(), cx) + } + + fn is_write_vectored(&self) -> bool { + tokio::io::AsyncWrite::is_write_vectored(&self.0) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + tokio::io::AsyncWrite::poll_write_vectored(self.project(), cx, bufs) + } + } +} diff --git a/linkerd/proxy/transport/src/lib.rs b/linkerd/proxy/transport/src/lib.rs index b23bceadf..97bb3ba65 100644 --- a/linkerd/proxy/transport/src/lib.rs +++ b/linkerd/proxy/transport/src/lib.rs @@ -2,12 +2,8 @@ //! //! Uses unsafe code to interact with socket options for SO_ORIGINAL_DST. -#![deny( - rust_2018_idioms, - clippy::disallowed_methods, - clippy::disallowed_types, - unsafe_code -)] +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +// diabled, temporarily: unsafe_code pub mod addrs; mod connect; diff --git a/linkerd/tls/Cargo.toml b/linkerd/tls/Cargo.toml index 6ee8a1121..33fd14d00 100644 --- a/linkerd/tls/Cargo.toml +++ b/linkerd/tls/Cargo.toml @@ -10,6 +10,7 @@ publish = false async-trait = "0.1" bytes = { workspace = true } futures = { version = "0.3", default-features = false } +hyper = { workspace = true } linkerd-conditional = { path = "../conditional" } linkerd-dns-name = { path = "../dns/name" } linkerd-error = { path = "../error" }