Abstract I/O interface into a trait. (#1020)

* Rename so_original_dst.rs to addr_info.rs.

Prepare for expanding the functionality of this module by renaming it.

Signed-off-by: Brian Smith <brian@briansmith.org>

* Abstract I/O interface into a trait.

Instead of pattern matching over an `Io` variant, use a `Box<Io>` to
abstract the I/O interface. This will make it easier to add a TLS
transport.

Signed-off-by: Brian Smith <brian@briansmith.org>
This commit is contained in:
Brian Smith 2018-05-26 10:04:31 -10:00 committed by GitHub
parent 4cca72fb92
commit 79a38327d2
5 changed files with 81 additions and 72 deletions

View File

@ -12,6 +12,7 @@ use tokio::{
use config::Addr; use config::Addr;
use transport::GetOriginalDst; use transport::GetOriginalDst;
use transport::Io;
pub type PlaintextSocket = TcpStream; pub type PlaintextSocket = TcpStream;
@ -36,7 +37,7 @@ pub struct Connecting(ConnectFuture);
/// subverted. /// subverted.
#[derive(Debug)] #[derive(Debug)]
pub struct Connection { pub struct Connection {
io: Io, io: Box<Io>,
/// This buffer gets filled up when "peeking" bytes on this Connection. /// This buffer gets filled up when "peeking" bytes on this Connection.
/// ///
/// This is used instead of MSG_PEEK in order to support TLS streams. /// This is used instead of MSG_PEEK in order to support TLS streams.
@ -46,11 +47,6 @@ pub struct Connection {
peek_buf: BytesMut, peek_buf: BytesMut,
} }
#[derive(Debug)]
enum Io {
Plain(PlaintextSocket),
}
/// A trait describing that a type can peek bytes. /// A trait describing that a type can peek bytes.
pub trait Peek { pub trait Peek {
/// An async attempt to peek bytes of this type without consuming. /// An async attempt to peek bytes of this type without consuming.
@ -154,29 +150,17 @@ impl Connection {
/// A constructor of `Connection` with a plain text TCP socket. /// A constructor of `Connection` with a plain text TCP socket.
pub fn plain(socket: PlaintextSocket) -> Self { pub fn plain(socket: PlaintextSocket) -> Self {
Connection { Connection {
io: Io::Plain(socket), io: Box::new(socket),
peek_buf: BytesMut::new(), peek_buf: BytesMut::new(),
} }
} }
pub fn original_dst_addr<T: GetOriginalDst>(&self, get: &T) -> Option<SocketAddr> { pub fn original_dst_addr<T: GetOriginalDst>(&self, get: &T) -> Option<SocketAddr> {
get.get_original_dst(self.socket()) get.get_original_dst(&self.io)
} }
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> { pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
self.socket().local_addr() self.io.local_addr()
}
// This must never be made public so that in the future `Connection` can
// control access to the plaintext socket for TLS, to ensure no private
// data is accidentally writen to the socket and to ensure no unprotected
// data is read from the socket. Each piece of information needed about the
// underlying socket should be exposed by its own minimal accessor function
// as is done above.
fn socket(&self) -> &PlaintextSocket {
match self.io {
Io::Plain(ref socket) => socket
}
} }
} }
@ -187,9 +171,7 @@ impl io::Read for Connection {
let peeked_len = self.peek_buf.len(); let peeked_len = self.peek_buf.len();
if peeked_len == 0 { if peeked_len == 0 {
match self.io { self.io.read(buf)
Io::Plain(ref mut t) => t.read(buf),
}
} else { } else {
let len = cmp::min(buf.len(), peeked_len); let len = cmp::min(buf.len(), peeked_len);
buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]); buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]);
@ -207,47 +189,35 @@ impl io::Read for Connection {
impl AsyncRead for Connection { impl AsyncRead for Connection {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self.io { self.io.prepare_uninitialized_buffer(buf)
Io::Plain(ref t) => t.prepare_uninitialized_buffer(buf),
}
} }
} }
impl io::Write for Connection { impl io::Write for Connection {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.io { self.io.write(buf)
Io::Plain(ref mut t) => t.write(buf),
}
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
match self.io { self.io.flush()
Io::Plain(ref mut t) => t.flush(),
}
} }
} }
impl AsyncWrite for Connection { impl AsyncWrite for Connection {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn shutdown(&mut self) -> Poll<(), io::Error> {
use std::net::Shutdown; try_ready!(AsyncWrite::shutdown(&mut self.io));
match self.io {
Io::Plain(ref mut t) => { // TCP shutdown the write side.
try_ready!(AsyncWrite::shutdown(t)); //
// TCP shutdown the write side. // If we're shutting down, then we definitely won't write
// // anymore. So, we should tell the remote about this. This
// If we're shutting down, then we definitely won't write // is relied upon in our TCP proxy, to start shutting down
// anymore. So, we should tell the remote about this. This // the pipe if one side closes.
// is relied upon in our TCP proxy, to start shutting down self.io.shutdown_write().map(Async::Ready)
// the pipe if one side closes.
TcpStream::shutdown(t, Shutdown::Write).map(Async::Ready)
},
}
} }
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> { fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
match self.io { self.io.write_buf(buf)
Io::Plain(ref mut t) => t.write_buf(buf),
}
} }
} }
@ -255,9 +225,7 @@ impl Peek for Connection {
fn poll_peek(&mut self) -> Poll<usize, io::Error> { fn poll_peek(&mut self) -> Poll<usize, io::Error> {
if self.peek_buf.is_empty() { if self.peek_buf.is_empty() {
self.peek_buf.reserve(8192); self.peek_buf.reserve(8192);
match self.io { self.io.read_buf(&mut self.peek_buf)
Io::Plain(ref mut t) => t.read_buf(&mut self.peek_buf),
}
} else { } else {
Ok(Async::Ready(self.peek_buf.len())) Ok(Async::Ready(self.peek_buf.len()))
} }

View File

@ -87,7 +87,7 @@ use inbound::Inbound;
use map_err::MapErr; use map_err::MapErr;
use task::MainRuntime; use task::MainRuntime;
use transparency::{HttpBody, Server}; use transparency::{HttpBody, Server};
pub use transport::{GetOriginalDst, SoOriginalDst}; pub use transport::{AddrInfo, GetOriginalDst, SoOriginalDst};
use outbound::Outbound; use outbound::Outbound;
/// Runs a sidecar proxy. /// Runs a sidecar proxy.

View File

@ -1,33 +1,59 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use std::fmt::Debug;
use std::io;
pub trait AddrInfo: Debug {
fn local_addr(&self) -> Result<SocketAddr, io::Error>;
fn get_original_dst(&self) -> Option<SocketAddr>;
}
impl<T: AddrInfo + ?Sized> AddrInfo for Box<T> {
fn local_addr(&self) -> Result<SocketAddr, io::Error> {
self.as_ref().local_addr()
}
fn get_original_dst(&self) -> Option<SocketAddr> {
self.as_ref().get_original_dst()
}
}
impl AddrInfo for TcpStream {
fn local_addr(&self) -> Result<SocketAddr, io::Error> {
TcpStream::local_addr(&self)
}
#[cfg(target_os = "linux")]
fn get_original_dst(&self) -> Option<SocketAddr> {
use self::linux;
use std::os::unix::io::AsRawFd;
let fd = self.as_raw_fd();
let r = unsafe { linux::so_original_dst(fd) };
r.ok()
}
#[cfg(not(target_os = "linux"))]
fn get_original_dst(&self) -> Option<SocketAddr> {
debug!("no support for SO_ORIGINAL_DST");
None
}
}
/// A generic way to get the original destination address of a socket. /// A generic way to get the original destination address of a socket.
/// ///
/// This is especially useful to allow tests to provide a mock implementation. /// This is especially useful to allow tests to provide a mock implementation.
pub trait GetOriginalDst { pub trait GetOriginalDst {
fn get_original_dst(&self, socket: &TcpStream) -> Option<SocketAddr>; fn get_original_dst(&self, socket: &AddrInfo) -> Option<SocketAddr>;
} }
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub struct SoOriginalDst; pub struct SoOriginalDst;
impl GetOriginalDst for SoOriginalDst { impl GetOriginalDst for SoOriginalDst {
#[cfg(not(target_os = "linux"))] fn get_original_dst(&self, sock: &AddrInfo) -> Option<SocketAddr> {
fn get_original_dst(&self, _: &TcpStream) -> Option<SocketAddr> {
debug!("no support for SO_ORIGINAL_DST");
None
}
// TODO change/remove once https://github.com/tokio-rs/tokio/issues/25 is addressed
#[cfg(target_os = "linux")]
fn get_original_dst(&self, sock: &TcpStream) -> Option<SocketAddr> {
use self::linux;
use std::os::unix::io::AsRawFd;
trace!("get_original_dst {:?}", sock); trace!("get_original_dst {:?}", sock);
sock.get_original_dst()
let res = unsafe { linux::so_original_dst(sock.as_raw_fd()) };
res.ok()
} }
} }

View File

@ -1,9 +1,24 @@
use std::io;
use std::net::Shutdown;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
mod connect; mod connect;
mod so_original_dst; mod addr_info;
pub use self::connect::{ pub use self::connect::{
Connect, Connect,
DnsNameAndPort, Host, HostAndPort, HostAndPortError, DnsNameAndPort, Host, HostAndPort, HostAndPortError,
LookupAddressAndConnect, LookupAddressAndConnect,
}; };
pub use self::so_original_dst::{GetOriginalDst, SoOriginalDst}; pub use self::addr_info::{AddrInfo, GetOriginalDst, SoOriginalDst};
pub trait Io: AddrInfo + AsyncRead + AsyncWrite + Send {
fn shutdown_write(&mut self) -> Result<(), io::Error>;
}
impl Io for TcpStream {
fn shutdown_write(&mut self) -> Result<(), io::Error> {
TcpStream::shutdown(self, Shutdown::Write)
}
}

View File

@ -105,7 +105,7 @@ struct DstInner {
} }
impl conduit_proxy::GetOriginalDst for MockOriginalDst { impl conduit_proxy::GetOriginalDst for MockOriginalDst {
fn get_original_dst(&self, sock: &TcpStream) -> Option<SocketAddr> { fn get_original_dst(&self, sock: &AddrInfo) -> Option<SocketAddr> {
sock.local_addr() sock.local_addr()
.ok() .ok()
.and_then(|local| { .and_then(|local| {