Abstract I/O interface into a trait. (#1020)
* Rename so_original_dst.rs to addr_info.rs. Prepare for expanding the functionality of this module by renaming it. Signed-off-by: Brian Smith <brian@briansmith.org> * Abstract I/O interface into a trait. Instead of pattern matching over an `Io` variant, use a `Box<Io>` to abstract the I/O interface. This will make it easier to add a TLS transport. Signed-off-by: Brian Smith <brian@briansmith.org>
This commit is contained in:
parent
4cca72fb92
commit
79a38327d2
|
@ -12,6 +12,7 @@ use tokio::{
|
|||
|
||||
use config::Addr;
|
||||
use transport::GetOriginalDst;
|
||||
use transport::Io;
|
||||
|
||||
pub type PlaintextSocket = TcpStream;
|
||||
|
||||
|
@ -36,7 +37,7 @@ pub struct Connecting(ConnectFuture);
|
|||
/// subverted.
|
||||
#[derive(Debug)]
|
||||
pub struct Connection {
|
||||
io: Io,
|
||||
io: Box<Io>,
|
||||
/// This buffer gets filled up when "peeking" bytes on this Connection.
|
||||
///
|
||||
/// This is used instead of MSG_PEEK in order to support TLS streams.
|
||||
|
@ -46,11 +47,6 @@ pub struct Connection {
|
|||
peek_buf: BytesMut,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Io {
|
||||
Plain(PlaintextSocket),
|
||||
}
|
||||
|
||||
/// A trait describing that a type can peek bytes.
|
||||
pub trait Peek {
|
||||
/// An async attempt to peek bytes of this type without consuming.
|
||||
|
@ -154,29 +150,17 @@ impl Connection {
|
|||
/// A constructor of `Connection` with a plain text TCP socket.
|
||||
pub fn plain(socket: PlaintextSocket) -> Self {
|
||||
Connection {
|
||||
io: Io::Plain(socket),
|
||||
io: Box::new(socket),
|
||||
peek_buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn original_dst_addr<T: GetOriginalDst>(&self, get: &T) -> Option<SocketAddr> {
|
||||
get.get_original_dst(self.socket())
|
||||
get.get_original_dst(&self.io)
|
||||
}
|
||||
|
||||
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
|
||||
self.socket().local_addr()
|
||||
}
|
||||
|
||||
// This must never be made public so that in the future `Connection` can
|
||||
// control access to the plaintext socket for TLS, to ensure no private
|
||||
// data is accidentally writen to the socket and to ensure no unprotected
|
||||
// data is read from the socket. Each piece of information needed about the
|
||||
// underlying socket should be exposed by its own minimal accessor function
|
||||
// as is done above.
|
||||
fn socket(&self) -> &PlaintextSocket {
|
||||
match self.io {
|
||||
Io::Plain(ref socket) => socket
|
||||
}
|
||||
self.io.local_addr()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -187,9 +171,7 @@ impl io::Read for Connection {
|
|||
let peeked_len = self.peek_buf.len();
|
||||
|
||||
if peeked_len == 0 {
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.read(buf),
|
||||
}
|
||||
self.io.read(buf)
|
||||
} else {
|
||||
let len = cmp::min(buf.len(), peeked_len);
|
||||
buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]);
|
||||
|
@ -207,47 +189,35 @@ impl io::Read for Connection {
|
|||
|
||||
impl AsyncRead for Connection {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
match self.io {
|
||||
Io::Plain(ref t) => t.prepare_uninitialized_buffer(buf),
|
||||
}
|
||||
self.io.prepare_uninitialized_buffer(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for Connection {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.write(buf),
|
||||
}
|
||||
self.io.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.flush(),
|
||||
}
|
||||
self.io.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Connection {
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
use std::net::Shutdown;
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => {
|
||||
try_ready!(AsyncWrite::shutdown(t));
|
||||
// TCP shutdown the write side.
|
||||
//
|
||||
// If we're shutting down, then we definitely won't write
|
||||
// anymore. So, we should tell the remote about this. This
|
||||
// is relied upon in our TCP proxy, to start shutting down
|
||||
// the pipe if one side closes.
|
||||
TcpStream::shutdown(t, Shutdown::Write).map(Async::Ready)
|
||||
},
|
||||
}
|
||||
try_ready!(AsyncWrite::shutdown(&mut self.io));
|
||||
|
||||
// TCP shutdown the write side.
|
||||
//
|
||||
// If we're shutting down, then we definitely won't write
|
||||
// anymore. So, we should tell the remote about this. This
|
||||
// is relied upon in our TCP proxy, to start shutting down
|
||||
// the pipe if one side closes.
|
||||
self.io.shutdown_write().map(Async::Ready)
|
||||
}
|
||||
|
||||
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.write_buf(buf),
|
||||
}
|
||||
self.io.write_buf(buf)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -255,9 +225,7 @@ impl Peek for Connection {
|
|||
fn poll_peek(&mut self) -> Poll<usize, io::Error> {
|
||||
if self.peek_buf.is_empty() {
|
||||
self.peek_buf.reserve(8192);
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.read_buf(&mut self.peek_buf),
|
||||
}
|
||||
self.io.read_buf(&mut self.peek_buf)
|
||||
} else {
|
||||
Ok(Async::Ready(self.peek_buf.len()))
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ use inbound::Inbound;
|
|||
use map_err::MapErr;
|
||||
use task::MainRuntime;
|
||||
use transparency::{HttpBody, Server};
|
||||
pub use transport::{GetOriginalDst, SoOriginalDst};
|
||||
pub use transport::{AddrInfo, GetOriginalDst, SoOriginalDst};
|
||||
use outbound::Outbound;
|
||||
|
||||
/// Runs a sidecar proxy.
|
||||
|
|
|
@ -1,33 +1,59 @@
|
|||
use std::net::SocketAddr;
|
||||
use tokio::net::TcpStream;
|
||||
use std::fmt::Debug;
|
||||
use std::io;
|
||||
|
||||
pub trait AddrInfo: Debug {
|
||||
fn local_addr(&self) -> Result<SocketAddr, io::Error>;
|
||||
fn get_original_dst(&self) -> Option<SocketAddr>;
|
||||
}
|
||||
|
||||
impl<T: AddrInfo + ?Sized> AddrInfo for Box<T> {
|
||||
fn local_addr(&self) -> Result<SocketAddr, io::Error> {
|
||||
self.as_ref().local_addr()
|
||||
}
|
||||
|
||||
fn get_original_dst(&self) -> Option<SocketAddr> {
|
||||
self.as_ref().get_original_dst()
|
||||
}
|
||||
}
|
||||
|
||||
impl AddrInfo for TcpStream {
|
||||
fn local_addr(&self) -> Result<SocketAddr, io::Error> {
|
||||
TcpStream::local_addr(&self)
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
fn get_original_dst(&self) -> Option<SocketAddr> {
|
||||
use self::linux;
|
||||
use std::os::unix::io::AsRawFd;
|
||||
|
||||
let fd = self.as_raw_fd();
|
||||
let r = unsafe { linux::so_original_dst(fd) };
|
||||
r.ok()
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
fn get_original_dst(&self) -> Option<SocketAddr> {
|
||||
debug!("no support for SO_ORIGINAL_DST");
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// A generic way to get the original destination address of a socket.
|
||||
///
|
||||
/// This is especially useful to allow tests to provide a mock implementation.
|
||||
pub trait GetOriginalDst {
|
||||
fn get_original_dst(&self, socket: &TcpStream) -> Option<SocketAddr>;
|
||||
fn get_original_dst(&self, socket: &AddrInfo) -> Option<SocketAddr>;
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct SoOriginalDst;
|
||||
|
||||
impl GetOriginalDst for SoOriginalDst {
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
fn get_original_dst(&self, _: &TcpStream) -> Option<SocketAddr> {
|
||||
debug!("no support for SO_ORIGINAL_DST");
|
||||
None
|
||||
}
|
||||
|
||||
// TODO change/remove once https://github.com/tokio-rs/tokio/issues/25 is addressed
|
||||
#[cfg(target_os = "linux")]
|
||||
fn get_original_dst(&self, sock: &TcpStream) -> Option<SocketAddr> {
|
||||
use self::linux;
|
||||
use std::os::unix::io::AsRawFd;
|
||||
|
||||
fn get_original_dst(&self, sock: &AddrInfo) -> Option<SocketAddr> {
|
||||
trace!("get_original_dst {:?}", sock);
|
||||
|
||||
let res = unsafe { linux::so_original_dst(sock.as_raw_fd()) };
|
||||
res.ok()
|
||||
sock.get_original_dst()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,9 +1,24 @@
|
|||
use std::io;
|
||||
use std::net::Shutdown;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
mod connect;
|
||||
mod so_original_dst;
|
||||
mod addr_info;
|
||||
|
||||
pub use self::connect::{
|
||||
Connect,
|
||||
DnsNameAndPort, Host, HostAndPort, HostAndPortError,
|
||||
LookupAddressAndConnect,
|
||||
};
|
||||
pub use self::so_original_dst::{GetOriginalDst, SoOriginalDst};
|
||||
pub use self::addr_info::{AddrInfo, GetOriginalDst, SoOriginalDst};
|
||||
|
||||
pub trait Io: AddrInfo + AsyncRead + AsyncWrite + Send {
|
||||
fn shutdown_write(&mut self) -> Result<(), io::Error>;
|
||||
}
|
||||
|
||||
impl Io for TcpStream {
|
||||
fn shutdown_write(&mut self) -> Result<(), io::Error> {
|
||||
TcpStream::shutdown(self, Shutdown::Write)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -105,7 +105,7 @@ struct DstInner {
|
|||
}
|
||||
|
||||
impl conduit_proxy::GetOriginalDst for MockOriginalDst {
|
||||
fn get_original_dst(&self, sock: &TcpStream) -> Option<SocketAddr> {
|
||||
fn get_original_dst(&self, sock: &AddrInfo) -> Option<SocketAddr> {
|
||||
sock.local_addr()
|
||||
.ok()
|
||||
.and_then(|local| {
|
||||
|
|
Loading…
Reference in New Issue