Abstract I/O interface into a trait. (#1020)

* Rename so_original_dst.rs to addr_info.rs.

Prepare for expanding the functionality of this module by renaming it.

Signed-off-by: Brian Smith <brian@briansmith.org>

* Abstract I/O interface into a trait.

Instead of pattern matching over an `Io` variant, use a `Box<Io>` to
abstract the I/O interface. This will make it easier to add a TLS
transport.

Signed-off-by: Brian Smith <brian@briansmith.org>
This commit is contained in:
Brian Smith 2018-05-26 10:04:31 -10:00 committed by GitHub
parent 4cca72fb92
commit 79a38327d2
5 changed files with 81 additions and 72 deletions

View File

@ -12,6 +12,7 @@ use tokio::{
use config::Addr;
use transport::GetOriginalDst;
use transport::Io;
pub type PlaintextSocket = TcpStream;
@ -36,7 +37,7 @@ pub struct Connecting(ConnectFuture);
/// subverted.
#[derive(Debug)]
pub struct Connection {
io: Io,
io: Box<Io>,
/// This buffer gets filled up when "peeking" bytes on this Connection.
///
/// This is used instead of MSG_PEEK in order to support TLS streams.
@ -46,11 +47,6 @@ pub struct Connection {
peek_buf: BytesMut,
}
#[derive(Debug)]
enum Io {
Plain(PlaintextSocket),
}
/// A trait describing that a type can peek bytes.
pub trait Peek {
/// An async attempt to peek bytes of this type without consuming.
@ -154,29 +150,17 @@ impl Connection {
/// A constructor of `Connection` with a plain text TCP socket.
pub fn plain(socket: PlaintextSocket) -> Self {
Connection {
io: Io::Plain(socket),
io: Box::new(socket),
peek_buf: BytesMut::new(),
}
}
pub fn original_dst_addr<T: GetOriginalDst>(&self, get: &T) -> Option<SocketAddr> {
get.get_original_dst(self.socket())
get.get_original_dst(&self.io)
}
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
self.socket().local_addr()
}
// This must never be made public so that in the future `Connection` can
// control access to the plaintext socket for TLS, to ensure no private
// data is accidentally writen to the socket and to ensure no unprotected
// data is read from the socket. Each piece of information needed about the
// underlying socket should be exposed by its own minimal accessor function
// as is done above.
fn socket(&self) -> &PlaintextSocket {
match self.io {
Io::Plain(ref socket) => socket
}
self.io.local_addr()
}
}
@ -187,9 +171,7 @@ impl io::Read for Connection {
let peeked_len = self.peek_buf.len();
if peeked_len == 0 {
match self.io {
Io::Plain(ref mut t) => t.read(buf),
}
self.io.read(buf)
} else {
let len = cmp::min(buf.len(), peeked_len);
buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]);
@ -207,47 +189,35 @@ impl io::Read for Connection {
impl AsyncRead for Connection {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self.io {
Io::Plain(ref t) => t.prepare_uninitialized_buffer(buf),
}
self.io.prepare_uninitialized_buffer(buf)
}
}
impl io::Write for Connection {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.io {
Io::Plain(ref mut t) => t.write(buf),
}
self.io.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
match self.io {
Io::Plain(ref mut t) => t.flush(),
}
self.io.flush()
}
}
impl AsyncWrite for Connection {
fn shutdown(&mut self) -> Poll<(), io::Error> {
use std::net::Shutdown;
match self.io {
Io::Plain(ref mut t) => {
try_ready!(AsyncWrite::shutdown(t));
// TCP shutdown the write side.
//
// If we're shutting down, then we definitely won't write
// anymore. So, we should tell the remote about this. This
// is relied upon in our TCP proxy, to start shutting down
// the pipe if one side closes.
TcpStream::shutdown(t, Shutdown::Write).map(Async::Ready)
},
}
try_ready!(AsyncWrite::shutdown(&mut self.io));
// TCP shutdown the write side.
//
// If we're shutting down, then we definitely won't write
// anymore. So, we should tell the remote about this. This
// is relied upon in our TCP proxy, to start shutting down
// the pipe if one side closes.
self.io.shutdown_write().map(Async::Ready)
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
match self.io {
Io::Plain(ref mut t) => t.write_buf(buf),
}
self.io.write_buf(buf)
}
}
@ -255,9 +225,7 @@ impl Peek for Connection {
fn poll_peek(&mut self) -> Poll<usize, io::Error> {
if self.peek_buf.is_empty() {
self.peek_buf.reserve(8192);
match self.io {
Io::Plain(ref mut t) => t.read_buf(&mut self.peek_buf),
}
self.io.read_buf(&mut self.peek_buf)
} else {
Ok(Async::Ready(self.peek_buf.len()))
}

View File

@ -87,7 +87,7 @@ use inbound::Inbound;
use map_err::MapErr;
use task::MainRuntime;
use transparency::{HttpBody, Server};
pub use transport::{GetOriginalDst, SoOriginalDst};
pub use transport::{AddrInfo, GetOriginalDst, SoOriginalDst};
use outbound::Outbound;
/// Runs a sidecar proxy.

View File

@ -1,33 +1,59 @@
use std::net::SocketAddr;
use tokio::net::TcpStream;
use std::fmt::Debug;
use std::io;
pub trait AddrInfo: Debug {
fn local_addr(&self) -> Result<SocketAddr, io::Error>;
fn get_original_dst(&self) -> Option<SocketAddr>;
}
impl<T: AddrInfo + ?Sized> AddrInfo for Box<T> {
fn local_addr(&self) -> Result<SocketAddr, io::Error> {
self.as_ref().local_addr()
}
fn get_original_dst(&self) -> Option<SocketAddr> {
self.as_ref().get_original_dst()
}
}
impl AddrInfo for TcpStream {
fn local_addr(&self) -> Result<SocketAddr, io::Error> {
TcpStream::local_addr(&self)
}
#[cfg(target_os = "linux")]
fn get_original_dst(&self) -> Option<SocketAddr> {
use self::linux;
use std::os::unix::io::AsRawFd;
let fd = self.as_raw_fd();
let r = unsafe { linux::so_original_dst(fd) };
r.ok()
}
#[cfg(not(target_os = "linux"))]
fn get_original_dst(&self) -> Option<SocketAddr> {
debug!("no support for SO_ORIGINAL_DST");
None
}
}
/// A generic way to get the original destination address of a socket.
///
/// This is especially useful to allow tests to provide a mock implementation.
pub trait GetOriginalDst {
fn get_original_dst(&self, socket: &TcpStream) -> Option<SocketAddr>;
fn get_original_dst(&self, socket: &AddrInfo) -> Option<SocketAddr>;
}
#[derive(Copy, Clone, Debug)]
pub struct SoOriginalDst;
impl GetOriginalDst for SoOriginalDst {
#[cfg(not(target_os = "linux"))]
fn get_original_dst(&self, _: &TcpStream) -> Option<SocketAddr> {
debug!("no support for SO_ORIGINAL_DST");
None
}
// TODO change/remove once https://github.com/tokio-rs/tokio/issues/25 is addressed
#[cfg(target_os = "linux")]
fn get_original_dst(&self, sock: &TcpStream) -> Option<SocketAddr> {
use self::linux;
use std::os::unix::io::AsRawFd;
fn get_original_dst(&self, sock: &AddrInfo) -> Option<SocketAddr> {
trace!("get_original_dst {:?}", sock);
let res = unsafe { linux::so_original_dst(sock.as_raw_fd()) };
res.ok()
sock.get_original_dst()
}
}

View File

@ -1,9 +1,24 @@
use std::io;
use std::net::Shutdown;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
mod connect;
mod so_original_dst;
mod addr_info;
pub use self::connect::{
Connect,
DnsNameAndPort, Host, HostAndPort, HostAndPortError,
LookupAddressAndConnect,
};
pub use self::so_original_dst::{GetOriginalDst, SoOriginalDst};
pub use self::addr_info::{AddrInfo, GetOriginalDst, SoOriginalDst};
pub trait Io: AddrInfo + AsyncRead + AsyncWrite + Send {
fn shutdown_write(&mut self) -> Result<(), io::Error>;
}
impl Io for TcpStream {
fn shutdown_write(&mut self) -> Result<(), io::Error> {
TcpStream::shutdown(self, Shutdown::Write)
}
}

View File

@ -105,7 +105,7 @@ struct DstInner {
}
impl conduit_proxy::GetOriginalDst for MockOriginalDst {
fn get_original_dst(&self, sock: &TcpStream) -> Option<SocketAddr> {
fn get_original_dst(&self, sock: &AddrInfo) -> Option<SocketAddr> {
sock.local_addr()
.ok()
.and_then(|local| {