From 79a38327d2312a4b2f1d9073af58f183ae2e666e Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Sat, 26 May 2018 10:04:31 -1000 Subject: [PATCH] Abstract I/O interface into a trait. (#1020) * Rename so_original_dst.rs to addr_info.rs. Prepare for expanding the functionality of this module by renaming it. Signed-off-by: Brian Smith * Abstract I/O interface into a trait. Instead of pattern matching over an `Io` variant, use a `Box` to abstract the I/O interface. This will make it easier to add a TLS transport. Signed-off-by: Brian Smith --- proxy/src/connection.rs | 72 ++++++------------- proxy/src/lib.rs | 2 +- .../{so_original_dst.rs => addr_info.rs} | 58 ++++++++++----- proxy/src/transport/mod.rs | 19 ++++- proxy/tests/support/proxy.rs | 2 +- 5 files changed, 81 insertions(+), 72 deletions(-) rename proxy/src/transport/{so_original_dst.rs => addr_info.rs} (79%) diff --git a/proxy/src/connection.rs b/proxy/src/connection.rs index c05c323fc..2faa10b6a 100644 --- a/proxy/src/connection.rs +++ b/proxy/src/connection.rs @@ -12,6 +12,7 @@ use tokio::{ use config::Addr; use transport::GetOriginalDst; +use transport::Io; pub type PlaintextSocket = TcpStream; @@ -36,7 +37,7 @@ pub struct Connecting(ConnectFuture); /// subverted. #[derive(Debug)] pub struct Connection { - io: Io, + io: Box, /// This buffer gets filled up when "peeking" bytes on this Connection. /// /// This is used instead of MSG_PEEK in order to support TLS streams. @@ -46,11 +47,6 @@ pub struct Connection { peek_buf: BytesMut, } -#[derive(Debug)] -enum Io { - Plain(PlaintextSocket), -} - /// A trait describing that a type can peek bytes. pub trait Peek { /// An async attempt to peek bytes of this type without consuming. @@ -154,29 +150,17 @@ impl Connection { /// A constructor of `Connection` with a plain text TCP socket. pub fn plain(socket: PlaintextSocket) -> Self { Connection { - io: Io::Plain(socket), + io: Box::new(socket), peek_buf: BytesMut::new(), } } pub fn original_dst_addr(&self, get: &T) -> Option { - get.get_original_dst(self.socket()) + get.get_original_dst(&self.io) } pub fn local_addr(&self) -> Result { - self.socket().local_addr() - } - - // This must never be made public so that in the future `Connection` can - // control access to the plaintext socket for TLS, to ensure no private - // data is accidentally writen to the socket and to ensure no unprotected - // data is read from the socket. Each piece of information needed about the - // underlying socket should be exposed by its own minimal accessor function - // as is done above. - fn socket(&self) -> &PlaintextSocket { - match self.io { - Io::Plain(ref socket) => socket - } + self.io.local_addr() } } @@ -187,9 +171,7 @@ impl io::Read for Connection { let peeked_len = self.peek_buf.len(); if peeked_len == 0 { - match self.io { - Io::Plain(ref mut t) => t.read(buf), - } + self.io.read(buf) } else { let len = cmp::min(buf.len(), peeked_len); buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]); @@ -207,47 +189,35 @@ impl io::Read for Connection { impl AsyncRead for Connection { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - match self.io { - Io::Plain(ref t) => t.prepare_uninitialized_buffer(buf), - } + self.io.prepare_uninitialized_buffer(buf) } } impl io::Write for Connection { fn write(&mut self, buf: &[u8]) -> io::Result { - match self.io { - Io::Plain(ref mut t) => t.write(buf), - } + self.io.write(buf) } fn flush(&mut self) -> io::Result<()> { - match self.io { - Io::Plain(ref mut t) => t.flush(), - } + self.io.flush() } } impl AsyncWrite for Connection { fn shutdown(&mut self) -> Poll<(), io::Error> { - use std::net::Shutdown; - match self.io { - Io::Plain(ref mut t) => { - try_ready!(AsyncWrite::shutdown(t)); - // TCP shutdown the write side. - // - // If we're shutting down, then we definitely won't write - // anymore. So, we should tell the remote about this. This - // is relied upon in our TCP proxy, to start shutting down - // the pipe if one side closes. - TcpStream::shutdown(t, Shutdown::Write).map(Async::Ready) - }, - } + try_ready!(AsyncWrite::shutdown(&mut self.io)); + + // TCP shutdown the write side. + // + // If we're shutting down, then we definitely won't write + // anymore. So, we should tell the remote about this. This + // is relied upon in our TCP proxy, to start shutting down + // the pipe if one side closes. + self.io.shutdown_write().map(Async::Ready) } fn write_buf(&mut self, buf: &mut B) -> Poll { - match self.io { - Io::Plain(ref mut t) => t.write_buf(buf), - } + self.io.write_buf(buf) } } @@ -255,9 +225,7 @@ impl Peek for Connection { fn poll_peek(&mut self) -> Poll { if self.peek_buf.is_empty() { self.peek_buf.reserve(8192); - match self.io { - Io::Plain(ref mut t) => t.read_buf(&mut self.peek_buf), - } + self.io.read_buf(&mut self.peek_buf) } else { Ok(Async::Ready(self.peek_buf.len())) } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 23a29e2ea..503026698 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -87,7 +87,7 @@ use inbound::Inbound; use map_err::MapErr; use task::MainRuntime; use transparency::{HttpBody, Server}; -pub use transport::{GetOriginalDst, SoOriginalDst}; +pub use transport::{AddrInfo, GetOriginalDst, SoOriginalDst}; use outbound::Outbound; /// Runs a sidecar proxy. diff --git a/proxy/src/transport/so_original_dst.rs b/proxy/src/transport/addr_info.rs similarity index 79% rename from proxy/src/transport/so_original_dst.rs rename to proxy/src/transport/addr_info.rs index 23fad1331..2729965dc 100644 --- a/proxy/src/transport/so_original_dst.rs +++ b/proxy/src/transport/addr_info.rs @@ -1,33 +1,59 @@ use std::net::SocketAddr; use tokio::net::TcpStream; +use std::fmt::Debug; +use std::io; + +pub trait AddrInfo: Debug { + fn local_addr(&self) -> Result; + fn get_original_dst(&self) -> Option; +} + +impl AddrInfo for Box { + fn local_addr(&self) -> Result { + self.as_ref().local_addr() + } + + fn get_original_dst(&self) -> Option { + self.as_ref().get_original_dst() + } +} + +impl AddrInfo for TcpStream { + fn local_addr(&self) -> Result { + TcpStream::local_addr(&self) + } + + #[cfg(target_os = "linux")] + fn get_original_dst(&self) -> Option { + use self::linux; + use std::os::unix::io::AsRawFd; + + let fd = self.as_raw_fd(); + let r = unsafe { linux::so_original_dst(fd) }; + r.ok() + } + + #[cfg(not(target_os = "linux"))] + fn get_original_dst(&self) -> Option { + debug!("no support for SO_ORIGINAL_DST"); + None + } +} /// A generic way to get the original destination address of a socket. /// /// This is especially useful to allow tests to provide a mock implementation. pub trait GetOriginalDst { - fn get_original_dst(&self, socket: &TcpStream) -> Option; + fn get_original_dst(&self, socket: &AddrInfo) -> Option; } #[derive(Copy, Clone, Debug)] pub struct SoOriginalDst; impl GetOriginalDst for SoOriginalDst { - #[cfg(not(target_os = "linux"))] - fn get_original_dst(&self, _: &TcpStream) -> Option { - debug!("no support for SO_ORIGINAL_DST"); - None - } - - // TODO change/remove once https://github.com/tokio-rs/tokio/issues/25 is addressed - #[cfg(target_os = "linux")] - fn get_original_dst(&self, sock: &TcpStream) -> Option { - use self::linux; - use std::os::unix::io::AsRawFd; - + fn get_original_dst(&self, sock: &AddrInfo) -> Option { trace!("get_original_dst {:?}", sock); - - let res = unsafe { linux::so_original_dst(sock.as_raw_fd()) }; - res.ok() + sock.get_original_dst() } } diff --git a/proxy/src/transport/mod.rs b/proxy/src/transport/mod.rs index 7f75bedf3..74af0a3a9 100644 --- a/proxy/src/transport/mod.rs +++ b/proxy/src/transport/mod.rs @@ -1,9 +1,24 @@ +use std::io; +use std::net::Shutdown; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; + mod connect; -mod so_original_dst; +mod addr_info; pub use self::connect::{ Connect, DnsNameAndPort, Host, HostAndPort, HostAndPortError, LookupAddressAndConnect, }; -pub use self::so_original_dst::{GetOriginalDst, SoOriginalDst}; +pub use self::addr_info::{AddrInfo, GetOriginalDst, SoOriginalDst}; + +pub trait Io: AddrInfo + AsyncRead + AsyncWrite + Send { + fn shutdown_write(&mut self) -> Result<(), io::Error>; +} + +impl Io for TcpStream { + fn shutdown_write(&mut self) -> Result<(), io::Error> { + TcpStream::shutdown(self, Shutdown::Write) + } +} diff --git a/proxy/tests/support/proxy.rs b/proxy/tests/support/proxy.rs index 142b311f9..0dcde25df 100644 --- a/proxy/tests/support/proxy.rs +++ b/proxy/tests/support/proxy.rs @@ -105,7 +105,7 @@ struct DstInner { } impl conduit_proxy::GetOriginalDst for MockOriginalDst { - fn get_original_dst(&self, sock: &TcpStream) -> Option { + fn get_original_dst(&self, sock: &AddrInfo) -> Option { sock.local_addr() .ok() .and_then(|local| {