From 016d748653920014bb8b97084ae75bde626a83fe Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 20 Jun 2018 16:31:48 -0700 Subject: [PATCH] proxy: re-enabled vectored writes through our dynamic Io trait object. (#1167) This adds `Io::write_buf_erased` that doesn't required `Self: Sized`, so it can be called on trait objects. By using this method, specialized methods of `TcpStream` (and others) can use their `write_buf` to do vectored writes. Since it can be easy to forget to call `Io::write_buf_erased` instead of `Io::write_buf`, the concept of making a `Box` has been made private. A new type, `BoxedIo`, implements all the super traits of `Io`, while making the `Io` trait private to the `transport` module. Anything hoping to use a `Box` can use a `BoxedIo` instead, and know that the write buf erase dance is taken care of. Adds a test to `transport::io` checking that the dance we've done does indeed call the underlying specialized `write_buf` method. Closes #1162 --- proxy/src/connection.rs | 8 +- proxy/src/transport/io.rs | 169 ++++++++++++++++++++++++++ proxy/src/transport/mod.rs | 16 +-- proxy/src/transport/tls/connection.rs | 6 +- 4 files changed, 180 insertions(+), 19 deletions(-) create mode 100644 proxy/src/transport/io.rs diff --git a/proxy/src/connection.rs b/proxy/src/connection.rs index 698006177..affbbda9f 100644 --- a/proxy/src/connection.rs +++ b/proxy/src/connection.rs @@ -12,7 +12,7 @@ use tokio::{ use conditional::Conditional; use ctx::transport::TlsStatus; use config::Addr; -use transport::{GetOriginalDst, Io, tls}; +use transport::{AddrInfo, BoxedIo, GetOriginalDst, tls}; pub struct BoundPort { inner: std::net::TcpListener, @@ -47,7 +47,7 @@ pub enum Connecting { /// subverted. #[derive(Debug)] pub struct Connection { - io: Box, + io: BoxedIo, /// This buffer gets filled up when "peeking" bytes on this Connection. /// /// This is used instead of MSG_PEEK in order to support TLS streams. @@ -213,7 +213,7 @@ impl Future for Connecting { impl Connection { fn plain(io: TcpStream, why_no_tls: tls::ReasonForNoTls) -> Self { Connection { - io: Box::new(io), + io: BoxedIo::new(io), peek_buf: BytesMut::new(), tls_status: Conditional::None(why_no_tls), } @@ -221,7 +221,7 @@ impl Connection { fn tls(tls: tls::Connection) -> Self { Connection { - io: Box::new(tls), + io: BoxedIo::new(tls), peek_buf: BytesMut::new(), tls_status: Conditional::Some(()), } diff --git a/proxy/src/transport/io.rs b/proxy/src/transport/io.rs new file mode 100644 index 000000000..5ff700905 --- /dev/null +++ b/proxy/src/transport/io.rs @@ -0,0 +1,169 @@ +use std::io; +use std::net::{Shutdown, SocketAddr}; + +use bytes::Buf; +use futures::Poll; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::AddrInfo; +use self::internal::Io; + +/// A public wrapper around a `Box`. +/// +/// This type ensures that the proper write_buf method is called, +/// to allow vectored writes to occur. +#[derive(Debug)] +pub struct BoxedIo(Box); + +impl BoxedIo { + pub fn new(io: T) -> Self { + BoxedIo(Box::new(io)) + } + + /// Since `Io` isn't publicly exported, but `Connection` wants + /// this method, it's just an inherent method. + pub fn shutdown_write(&mut self) -> Result<(), io::Error> { + self.0.shutdown_write() + } +} + +impl io::Read for BoxedIo { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl io::Write for BoxedIo { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } +} + +impl AsyncRead for BoxedIo { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.0.prepare_uninitialized_buffer(buf) + } +} + +impl AsyncWrite for BoxedIo { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.0.shutdown() + } + + fn write_buf(&mut self, mut buf: &mut B) -> Poll { + // A trait object of AsyncWrite would use the default write_buf, + // which doesn't allow vectored writes. Going through this method + // allows the trait object to call the specialized write_buf method. + self.0.write_buf_erased(&mut buf) + } +} + +impl AddrInfo for BoxedIo { + fn local_addr(&self) -> Result { + self.0.local_addr() + } + + fn get_original_dst(&self) -> Option { + self.0.get_original_dst() + } +} + +pub(super) mod internal { + use std::io; + use tokio::net::TcpStream; + use super::{AddrInfo, AsyncRead, AsyncWrite, Buf, Poll, Shutdown}; + + /// This trait is private, since it's purpose is for creating a dynamic + /// trait object, but doing so without care can lead not getting vectored + /// writes. + /// + /// Instead, used the concrete `BoxedIo` type. + pub trait Io: AddrInfo + AsyncRead + AsyncWrite + Send { + fn shutdown_write(&mut self) -> Result<(), io::Error>; + + /// This method is to allow using `Async::write_buf` even through a + /// trait object. + fn write_buf_erased(&mut self, buf: &mut Buf) -> Poll; + } + + impl Io for TcpStream { + fn shutdown_write(&mut self) -> Result<(), io::Error> { + TcpStream::shutdown(self, Shutdown::Write) + } + + fn write_buf_erased(&mut self, mut buf: &mut Buf) -> Poll { + self.write_buf(&mut buf) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug)] + struct WriteBufDetector; + + impl io::Read for WriteBufDetector { + fn read(&mut self, _: &mut [u8]) -> io::Result { + unimplemented!() + } + } + + impl io::Write for WriteBufDetector { + fn write(&mut self, _: &[u8]) -> io::Result { + panic!("BoxedIo called wrong write_buf method"); + } + fn flush(&mut self) -> io::Result<()> { + unimplemented!() + } + } + + impl AsyncRead for WriteBufDetector {} + + impl AsyncWrite for WriteBufDetector { + fn shutdown(&mut self) -> Poll<(), io::Error> { + unimplemented!() + } + + fn write_buf(&mut self, _: &mut B) -> Poll { + Ok(0.into()) + } + } + + impl AddrInfo for WriteBufDetector { + fn local_addr(&self) -> Result { + unimplemented!() + } + + fn get_original_dst(&self) -> Option { + unimplemented!() + } + } + + impl Io for WriteBufDetector { + fn shutdown_write(&mut self) -> Result<(), io::Error> { + unimplemented!() + } + + fn write_buf_erased(&mut self, mut buf: &mut Buf) -> Poll { + self.write_buf(&mut buf) + } + } + + + #[test] + fn boxed_io_uses_vectored_io() { + use bytes::IntoBuf; + let mut io = BoxedIo::new(WriteBufDetector); + + // This method will trigger the panic in WriteBufDetector::write IFF + // BoxedIo doesn't call write_buf_erased, but write_buf, and triggering + // a regular write. + io.write_buf(&mut "hello".into_buf()).expect("write_buf"); + } +} diff --git a/proxy/src/transport/mod.rs b/proxy/src/transport/mod.rs index d56501a36..abb9243c7 100644 --- a/proxy/src/transport/mod.rs +++ b/proxy/src/transport/mod.rs @@ -1,10 +1,6 @@ -use std::io; -use std::net::Shutdown; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::net::TcpStream; - mod connect; mod addr_info; +mod io; pub mod tls; pub use self::connect::{ @@ -13,13 +9,5 @@ pub use self::connect::{ LookupAddressAndConnect, }; pub use self::addr_info::{AddrInfo, GetOriginalDst, SoOriginalDst}; +pub use self::io::BoxedIo; -pub trait Io: AddrInfo + AsyncRead + AsyncWrite + Send { - fn shutdown_write(&mut self) -> Result<(), io::Error>; -} - -impl Io for TcpStream { - fn shutdown_write(&mut self) -> Result<(), io::Error> { - TcpStream::shutdown(self, Shutdown::Write) - } -} diff --git a/proxy/src/transport/tls/connection.rs b/proxy/src/transport/tls/connection.rs index d38b760f4..4eb91676f 100644 --- a/proxy/src/transport/tls/connection.rs +++ b/proxy/src/transport/tls/connection.rs @@ -6,7 +6,7 @@ use futures::Future; use tokio::prelude::*; use tokio::net::TcpStream; -use transport::{AddrInfo, Io}; +use transport::{AddrInfo, io::internal::Io}; use super::{ identity::Identity, @@ -110,4 +110,8 @@ impl Io for Connection { fn shutdown_write(&mut self) -> Result<(), io::Error> { self.0.get_mut().0.shutdown_write() } + + fn write_buf_erased(&mut self, mut buf: &mut Buf) -> Poll { + self.0.write_buf(&mut buf) + } }