From 56dd9f55d8b2ef25548d2cb5902334974880cff1 Mon Sep 17 00:00:00 2001 From: katelyn martin Date: Wed, 19 Feb 2025 18:15:34 -0500 Subject: [PATCH] wip: app/core, io, meshtls, proxy/transport, tls --- Cargo.lock | 10 +++++ linkerd/app/core/Cargo.toml | 1 + linkerd/app/core/src/control.rs | 13 ++++-- linkerd/app/core/src/errors/body.rs | 27 +++++-------- linkerd/io/Cargo.toml | 1 + linkerd/io/src/either.rs | 54 +++++++++++++++++++++++++ linkerd/io/src/prefixed.rs | 55 ++++++++++++++++++++++++++ linkerd/io/src/scoped.rs | 46 +++++++++++++++++++++ linkerd/io/src/sensor.rs | 36 +++++++++++++++++ linkerd/meshtls/Cargo.toml | 2 + linkerd/meshtls/boring/Cargo.toml | 2 + linkerd/meshtls/boring/src/client.rs | 27 +++++++++---- linkerd/meshtls/rustls/Cargo.toml | 2 + linkerd/meshtls/rustls/src/client.rs | 32 ++++++++++----- linkerd/meshtls/src/client.rs | 50 +++++++++++++++++++++++ linkerd/proxy/transport/Cargo.toml | 1 + linkerd/proxy/transport/src/connect.rs | 12 ++++-- linkerd/tls/Cargo.toml | 1 + 18 files changed, 329 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cce6db25c..c70baa39c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1408,6 +1408,7 @@ dependencies = [ "http 1.2.0", "http-body", "hyper", + "hyper-util", "ipnet", "linkerd-addr", "linkerd-conditional", @@ -1995,6 +1996,7 @@ dependencies = [ "async-trait", "bytes", "futures", + "hyper", "hyper-util", "linkerd-errno", "pin-project", @@ -2008,6 +2010,8 @@ name = "linkerd-meshtls" version = "0.1.0" dependencies = [ "futures", + "hyper", + "hyper-util", "linkerd-conditional", "linkerd-dns-name", "linkerd-error", @@ -2033,6 +2037,8 @@ dependencies = [ "boring", "futures", "hex", + "hyper", + "hyper-util", "linkerd-dns-name", "linkerd-error", "linkerd-identity", @@ -2052,6 +2058,8 @@ name = "linkerd-meshtls-rustls" version = "0.1.0" dependencies = [ "futures", + "hyper", + "hyper-util", "linkerd-dns-name", "linkerd-error", "linkerd-identity", @@ -2486,6 +2494,7 @@ name = "linkerd-proxy-transport" version = "0.1.0" dependencies = [ "futures", + "hyper-util", "libc", "linkerd-error", "linkerd-io", @@ -2633,6 +2642,7 @@ dependencies = [ "async-trait", "bytes", "futures", + "hyper", "linkerd-conditional", "linkerd-dns-name", "linkerd-error", diff --git a/linkerd/app/core/Cargo.toml b/linkerd/app/core/Cargo.toml index 7c343a241..111457954 100644 --- a/linkerd/app/core/Cargo.toml +++ b/linkerd/app/core/Cargo.toml @@ -18,6 +18,7 @@ drain = { version = "0.1", features = ["retain"] } http = { workspace = true } http-body = { workspace = true } hyper = { workspace = true, features = ["http1", "http2"] } +hyper-util = { workspace = true } futures = { version = "0.3", default-features = false } ipnet = "2.11" prometheus-client = "0.22" diff --git a/linkerd/app/core/src/control.rs b/linkerd/app/core/src/control.rs index 92685c852..6ccb1e975 100644 --- a/linkerd/app/core/src/control.rs +++ b/linkerd/app/core/src/control.rs @@ -69,8 +69,10 @@ impl fmt::Display for ControlAddr { } } -pub type RspBody = - linkerd_http_metrics::requests::ResponseBody, classify::Eos>; +pub type RspBody = linkerd_http_metrics::requests::ResponseBody< + http::balance::Body, + classify::Eos, +>; #[derive(Clone, Debug, Default)] pub struct Metrics { @@ -129,9 +131,9 @@ impl Config { self.connect.user_timeout, )) .push(tls::Client::layer(identity)) - .push_connect_timeout(self.connect.timeout) + .push_connect_timeout(self.connect.timeout) // Client .push_map_target(|(_version, target)| target) - .push(self::client::layer(self.connect.http2)) + .push(self::client::layer::<_, _>(self.connect.http2)) .push_on_service(svc::MapErr::layer_boxed()) .into_new_service(); @@ -147,6 +149,8 @@ impl Config { .push_new_reconnect(self.connect.backoff) .instrument(|t: &self::client::Target| info_span!("endpoint", addr = %t.addr)); + todo!(); + /* let balance = endpoint .lift_new() .push(self::balance::layer(metrics.balance, dns, resolve_backoff)) @@ -161,6 +165,7 @@ impl Config { .push_on_service(svc::BoxCloneSyncService::layer()) .push(svc::ArcNewService::layer()) .into_inner() + */ } } diff --git a/linkerd/app/core/src/errors/body.rs b/linkerd/app/core/src/errors/body.rs index 975979838..10f23b07d 100644 --- a/linkerd/app/core/src/errors/body.rs +++ b/linkerd/app/core/src/errors/body.rs @@ -3,6 +3,7 @@ use super::{ respond::{HttpRescue, SyntheticHttpResponse}, }; use http::{header::HeaderValue, HeaderMap}; +use http_body::Frame; use linkerd_error::{Error, Result}; use pin_project::pin_project; use std::{ @@ -66,19 +67,18 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let ResponseBodyProj(inner) = self.as_mut().project(); match inner.project() { - InnerProj::Passthru(inner) => inner.poll_data(cx), - InnerProj::Rescued { trailers: _ } => Poll::Ready(None), + InnerProj::Passthru(inner) => inner.poll_frame(cx), InnerProj::GrpcRescue { inner, rescue, emit_headers, - } => match inner.poll_data(cx) { + } => match inner.poll_frame(cx) { Poll::Ready(Some(Err(error))) => { // The inner body has yielded an error, which we will try to rescue. If so, // store our synthetic trailers reporting the error. @@ -88,19 +88,10 @@ where } data => data, }, - } - } - - #[inline] - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - let ResponseBodyProj(inner) = self.project(); - match inner.project() { - InnerProj::Passthru(inner) => inner.poll_trailers(cx), - InnerProj::GrpcRescue { inner, .. } => inner.poll_trailers(cx), - InnerProj::Rescued { trailers } => Poll::Ready(Ok(trailers.take())), + InnerProj::Rescued { trailers } => { + let trailers = trailers.take().map(Frame::trailers).map(Ok); + Poll::Ready(trailers) + } } } diff --git a/linkerd/io/Cargo.toml b/linkerd/io/Cargo.toml index 41f13bab3..aabddb14b 100644 --- a/linkerd/io/Cargo.toml +++ b/linkerd/io/Cargo.toml @@ -16,6 +16,7 @@ default = [] async-trait = "0.1" futures = { version = "0.3", default-features = false } bytes = { workspace = true } +hyper = { workspace = true, default-features = false } hyper-util = { workspace = true, features = ["tokio"] } linkerd-errno = { path = "../errno" } tokio = { version = "1", features = ["io-util", "net"] } diff --git a/linkerd/io/src/either.rs b/linkerd/io/src/either.rs index 142822320..0f930f8b3 100644 --- a/linkerd/io/src/either.rs +++ b/linkerd/io/src/either.rs @@ -47,6 +47,19 @@ impl io::AsyncRead for EitherIo { } } +impl hyper::rt::Read for EitherIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> io::Poll<()> { + match self.project() { + EitherIoProj::Left(l) => l.poll_read(cx, buf), + EitherIoProj::Right(r) => r.poll_read(cx, buf), + } + } +} + impl io::AsyncWrite for EitherIo { #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { @@ -92,3 +105,44 @@ impl io::AsyncWrite for EitherIo { } } } + +impl hyper::rt::Write for EitherIo { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { + match self.project() { + EitherIoProj::Left(l) => l.poll_write(cx, buf), + EitherIoProj::Right(r) => r.poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + match self.project() { + EitherIoProj::Left(l) => l.poll_flush(cx), + EitherIoProj::Right(r) => r.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + match self.project() { + EitherIoProj::Left(l) => l.poll_shutdown(cx), + EitherIoProj::Right(r) => r.poll_shutdown(cx), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + EitherIo::Left(l) => l.is_write_vectored(), + EitherIo::Right(r) => r.is_write_vectored(), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> io::Poll { + match self.project() { + EitherIoProj::Left(l) => l.poll_write_vectored(cx, bufs), + EitherIoProj::Right(r) => r.poll_write_vectored(cx, bufs), + } + } +} diff --git a/linkerd/io/src/prefixed.rs b/linkerd/io/src/prefixed.rs index 27237cac1..1cee707d2 100644 --- a/linkerd/io/src/prefixed.rs +++ b/linkerd/io/src/prefixed.rs @@ -78,6 +78,35 @@ impl io::AsyncRead for PrefixedIo { } } +impl hyper::rt::Read for PrefixedIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: hyper::rt::ReadBufCursor<'_>, + ) -> io::Poll<()> { + // XXX(kate): this is copy-pasted from `io::AsyncRead`, above. + let this = self.project(); + // Check the length only once, since looking as the length + // of a Bytes isn't as cheap as the length of a &[u8]. + let peeked_len = this.prefix.len(); + + if peeked_len == 0 { + this.io.poll_read(cx, buf) + } else { + let len = cmp::min(buf.remaining(), peeked_len); + buf.put_slice(&this.prefix.as_ref()[..len]); + this.prefix.advance(len); + // If we've finally emptied the prefix, drop it so we don't + // hold onto the allocated memory any longer. We won't peek + // again. + if peeked_len == len { + *this.prefix = Bytes::new(); + } + io::Poll::Ready(Ok(())) + } + } +} + impl io::Write for PrefixedIo { #[inline] fn write(&mut self, buf: &[u8]) -> io::Result { @@ -120,3 +149,29 @@ impl io::AsyncWrite for PrefixedIo { self.io.is_write_vectored() } } + +impl hyper::rt::Write for PrefixedIo { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { + self.project().io.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + self.project().io.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + self.project().io.poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> io::Poll { + self.project().io.poll_write_vectored(cx, bufs) + } +} diff --git a/linkerd/io/src/scoped.rs b/linkerd/io/src/scoped.rs index c52360a5f..f5ffa4af8 100644 --- a/linkerd/io/src/scoped.rs +++ b/linkerd/io/src/scoped.rs @@ -89,6 +89,17 @@ impl io::AsyncRead for ScopedIo { } } +impl hyper::rt::Read for ScopedIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> io::Poll<()> { + let this = self.project(); + this.io.poll_read(cx, buf).map_err(this.scope.err()) + } +} + impl io::Write for ScopedIo { #[inline] fn write(&mut self, buf: &[u8]) -> io::Result { @@ -138,3 +149,38 @@ impl io::AsyncWrite for ScopedIo { self.io.is_write_vectored() } } + +impl hyper::rt::Write for ScopedIo { + #[inline] + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { + let this = self.project(); + this.io.poll_write(cx, buf).map_err(this.scope.err()) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + let this = self.project(); + this.io.poll_flush(cx).map_err(this.scope.err()) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { + let this = self.project(); + this.io.poll_shutdown(cx).map_err(this.scope.err()) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> io::Poll { + let this = self.project(); + this.io + .poll_write_vectored(cx, bufs) + .map_err(this.scope.err()) + } +} diff --git a/linkerd/io/src/sensor.rs b/linkerd/io/src/sensor.rs index 9175b2ab3..04fae1f15 100644 --- a/linkerd/io/src/sensor.rs +++ b/linkerd/io/src/sensor.rs @@ -77,6 +77,42 @@ impl AsyncWrite for SensorIo { } } +impl hyper::rt::Write for SensorIo { + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.project(); + this.sensor.record_error(this.io.poll_shutdown(cx)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.project(); + this.sensor.record_error(this.io.poll_flush(cx)) + } + + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll { + let this = self.project(); + let bytes = ready!(this.sensor.record_error(this.io.poll_write(cx, buf)))?; + this.sensor.record_write(bytes); + Poll::Ready(Ok(bytes)) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll { + let this = self.project(); + let bytes = ready!(this + .sensor + .record_error(this.io.poll_write_vectored(cx, bufs)))?; + this.sensor.record_write(bytes); + Poll::Ready(Ok(bytes)) + } + + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } +} + impl PeerAddr for SensorIo { fn peer_addr(&self) -> Result { self.io.peer_addr() diff --git a/linkerd/meshtls/Cargo.toml b/linkerd/meshtls/Cargo.toml index 1373c0fc1..c84df4261 100644 --- a/linkerd/meshtls/Cargo.toml +++ b/linkerd/meshtls/Cargo.toml @@ -15,6 +15,8 @@ __has_any_tls_impls = [] [dependencies] futures = { version = "0.3", default-features = false } +hyper = { workspace = true } +hyper-util = { workspace = true } pin-project = "1" linkerd-dns-name = { path = "../dns/name" } diff --git a/linkerd/meshtls/boring/Cargo.toml b/linkerd/meshtls/boring/Cargo.toml index 61618e91e..054e6b40b 100644 --- a/linkerd/meshtls/boring/Cargo.toml +++ b/linkerd/meshtls/boring/Cargo.toml @@ -10,6 +10,8 @@ publish = false boring = "4" futures = { version = "0.3", default-features = false } hex = "0.4" # used for debug logging +hyper = { workspace = true } +hyper-util = { workspace = true } linkerd-error = { path = "../../error" } linkerd-dns-name = { path = "../../dns/name" } linkerd-identity = { path = "../../identity" } diff --git a/linkerd/meshtls/boring/src/client.rs b/linkerd/meshtls/boring/src/client.rs index d4f1d16cf..66743b150 100644 --- a/linkerd/meshtls/boring/src/client.rs +++ b/linkerd/meshtls/boring/src/client.rs @@ -21,7 +21,7 @@ pub struct Connect { pub type ConnectFuture = Pin>> + Send>>; #[derive(Debug)] -pub struct ClientIo(tokio_boring::SslStream); +pub struct ClientIo(hyper_util::rt::TokioIo>); // === impl NewClient === @@ -117,7 +117,7 @@ where "Initiated TLS connection" ); trace!(peer.id = %server_id, peer.name = %server_name); - Ok(ClientIo(io)) + Ok(ClientIo(hyper_util::rt::TokioIo::new(io))) }) } } @@ -131,6 +131,16 @@ impl io::AsyncRead for ClientIo { cx: &mut Context<'_>, buf: &mut io::ReadBuf<'_>, ) -> io::Poll<()> { + Pin::new(self.0.inner_mut()).poll_read(cx, buf) + } +} + +impl hyper::rt::Read for ClientIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> std::task::Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } @@ -138,17 +148,17 @@ impl io::AsyncRead for ClientIo { impl io::AsyncWrite for ClientIo { #[inline] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - Pin::new(&mut self.0).poll_flush(cx) + Pin::new(self.0.inner_mut()).poll_flush(cx) } #[inline] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - Pin::new(&mut self.0).poll_shutdown(cx) + Pin::new(self.0.inner_mut()).poll_shutdown(cx) } #[inline] fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { - Pin::new(&mut self.0).poll_write(cx, buf) + Pin::new(self.0.inner_mut()).poll_write(cx, buf) } #[inline] @@ -157,12 +167,12 @@ impl io::AsyncWrite for ClientIo { cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> io::Poll { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + Pin::new(self.0.inner_mut()).poll_write_vectored(cx, bufs) } #[inline] fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() + self.0.inner().is_write_vectored() } } @@ -170,6 +180,7 @@ impl ClientIo { #[inline] pub fn negotiated_protocol(&self) -> Option> { self.0 + .inner() .ssl() .selected_alpn_protocol() .map(NegotiatedProtocolRef) @@ -179,6 +190,6 @@ impl ClientIo { impl io::PeerAddr for ClientIo { #[inline] fn peer_addr(&self) -> io::Result { - self.0.get_ref().peer_addr() + self.0.inner().get_ref().peer_addr() } } diff --git a/linkerd/meshtls/rustls/Cargo.toml b/linkerd/meshtls/rustls/Cargo.toml index b13ec1455..0be0ea7e4 100644 --- a/linkerd/meshtls/rustls/Cargo.toml +++ b/linkerd/meshtls/rustls/Cargo.toml @@ -11,6 +11,8 @@ test-util = ["linkerd-tls-test-util"] [dependencies] futures = { version = "0.3", default-features = false } +hyper = { workspace = true } +hyper-util = { workspace = true } ring = { version = "0.17", features = ["std"] } rustls-pemfile = "2.2" rustls-webpki = { version = "0.102.8", features = ["std"] } diff --git a/linkerd/meshtls/rustls/src/client.rs b/linkerd/meshtls/rustls/src/client.rs index 9856d3899..eef88b736 100644 --- a/linkerd/meshtls/rustls/src/client.rs +++ b/linkerd/meshtls/rustls/src/client.rs @@ -3,7 +3,10 @@ use linkerd_identity as id; use linkerd_io as io; use linkerd_meshtls_verifier as verifier; use linkerd_stack::{NewService, Service}; -use linkerd_tls::{client::AlpnProtocols, ClientTls, NegotiatedProtocolRef}; +use linkerd_tls::{ + client::{self, AlpnProtocols}, + ClientTls, NegotiatedProtocolRef, +}; use std::{convert::TryFrom, pin::Pin, sync::Arc, task::Context}; use tokio::sync::watch; use tokio_rustls::rustls::{self, pki_types::CertificateDer, ClientConfig}; @@ -25,7 +28,7 @@ pub struct Connect { pub type ConnectFuture = Pin>> + Send>>; #[derive(Debug)] -pub struct ClientIo(tokio_rustls::client::TlsStream); +pub struct ClientIo(hyper_util::rt::TokioIo>); // === impl NewClient === @@ -115,7 +118,7 @@ where let (_, conn) = s.get_ref(); let end_cert = extract_cert(conn)?; verifier::verify_id(end_cert, &server_id)?; - Ok(ClientIo(s)) + Ok(ClientIo(hyper_util::rt::TokioIo::new(s))) }), ) } @@ -130,6 +133,16 @@ impl io::AsyncRead for ClientIo { cx: &mut Context<'_>, buf: &mut io::ReadBuf<'_>, ) -> io::Poll<()> { + Pin::new(self.0.inner_mut()).poll_read(cx, buf) + } +} + +impl hyper::rt::Read for ClientIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> std::task::Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } @@ -137,17 +150,17 @@ impl io::AsyncRead for ClientIo { impl io::AsyncWrite for ClientIo { #[inline] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - Pin::new(&mut self.0).poll_flush(cx) + Pin::new(self.0.inner_mut()).poll_flush(cx) } #[inline] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - Pin::new(&mut self.0).poll_shutdown(cx) + Pin::new(self.0.inner_mut()).poll_shutdown(cx) } #[inline] fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { - Pin::new(&mut self.0).poll_write(cx, buf) + Pin::new(self.0.inner_mut()).poll_write(cx, buf) } #[inline] @@ -156,12 +169,12 @@ impl io::AsyncWrite for ClientIo { cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> io::Poll { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + Pin::new(self.0.inner_mut()).poll_write_vectored(cx, bufs) } #[inline] fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() + self.0.inner().is_write_vectored() } } @@ -169,6 +182,7 @@ impl ClientIo { #[inline] pub fn negotiated_protocol(&self) -> Option> { self.0 + .inner() .get_ref() .1 .alpn_protocol() @@ -179,6 +193,6 @@ impl ClientIo { impl io::PeerAddr for ClientIo { #[inline] fn peer_addr(&self) -> io::Result { - self.0.get_ref().0.peer_addr() + self.0.inner().get_ref().0.peer_addr() } } diff --git a/linkerd/meshtls/src/client.rs b/linkerd/meshtls/src/client.rs index 42a80d299..166e582c3 100644 --- a/linkerd/meshtls/src/client.rs +++ b/linkerd/meshtls/src/client.rs @@ -180,6 +180,23 @@ impl io::AsyncRead for ClientIo { } } +impl hyper::rt::Read for ClientIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: hyper::rt::ReadBufCursor<'_>, + ) -> io::Poll<()> { + match self.project() { + #[cfg(feature = "boring")] + ClientIoProj::Boring(io) => io.poll_read(cx, buf), + #[cfg(feature = "rustls")] + ClientIoProj::Rustls(io) => io.poll_read(cx, buf), + #[cfg(not(feature = "__has_any_tls_impls"))] + _ => crate::no_tls!(cx, buf), + } + } +} + impl io::AsyncWrite for ClientIo { #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { @@ -251,6 +268,39 @@ impl io::AsyncWrite for ClientIo { } } +impl hyper::rt::Write for ClientIo { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + todo!() + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + todo!() + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + todo!() + } + + fn is_write_vectored(&self) -> bool { + todo!() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + todo!() + } +} + impl io::PeerAddr for ClientIo { #[inline] fn peer_addr(&self) -> io::Result { diff --git a/linkerd/proxy/transport/Cargo.toml b/linkerd/proxy/transport/Cargo.toml index 86260e07b..794a798b0 100644 --- a/linkerd/proxy/transport/Cargo.toml +++ b/linkerd/proxy/transport/Cargo.toml @@ -11,6 +11,7 @@ Transport-level implementations that rely on core proxy infrastructure [dependencies] futures = { version = "0.3", default-features = false } +hyper-util = { workspace = true } linkerd-error = { path = "../../error" } linkerd-io = { path = "../../io" } linkerd-stack = { path = "../../stack" } diff --git a/linkerd/proxy/transport/src/connect.rs b/linkerd/proxy/transport/src/connect.rs index d11b3753c..c57f4d064 100644 --- a/linkerd/proxy/transport/src/connect.rs +++ b/linkerd/proxy/transport/src/connect.rs @@ -6,7 +6,6 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tokio::net::TcpStream; use tracing::debug; #[derive(Copy, Clone, Debug)] @@ -15,6 +14,8 @@ pub struct ConnectTcp { user_timeout: UserTimeout, } +type TcpStream = io::ScopedIo>; + impl ConnectTcp { pub fn new(keepalive: Keepalive, user_timeout: UserTimeout) -> Self { Self { @@ -25,7 +26,7 @@ impl ConnectTcp { } impl>> Service for ConnectTcp { - type Response = (io::ScopedIo, Local); + type Response = (TcpStream, Local); type Error = io::Error; type Future = Pin> + Send + Sync + 'static>>; @@ -39,7 +40,7 @@ impl>> Service for ConnectTcp { let Remote(ServerAddr(addr)) = t.param(); debug!(server.addr = %addr, "Connecting"); Box::pin(async move { - let io = TcpStream::connect(&addr).await?; + let io = tokio::net::TcpStream::connect(&addr).await?; super::set_nodelay_or_warn(&io); let io = super::set_keepalive_or_warn(io, keepalive)?; let io = super::set_user_timeout_or_warn(io, user_timeout)?; @@ -49,7 +50,10 @@ impl>> Service for ConnectTcp { ?keepalive, "Connected", ); - Ok((io::ScopedIo::client(io), Local(ClientAddr(local_addr)))) + Ok(( + io::ScopedIo::client(hyper_util::rt::TokioIo::new(io)), + Local(ClientAddr(local_addr)), + )) }) } } diff --git a/linkerd/tls/Cargo.toml b/linkerd/tls/Cargo.toml index 6ee8a1121..33fd14d00 100644 --- a/linkerd/tls/Cargo.toml +++ b/linkerd/tls/Cargo.toml @@ -10,6 +10,7 @@ publish = false async-trait = "0.1" bytes = { workspace = true } futures = { version = "0.3", default-features = false } +hyper = { workspace = true } linkerd-conditional = { path = "../conditional" } linkerd-dns-name = { path = "../dns/name" } linkerd-error = { path = "../error" }