diff --git a/proxy/src/connection.rs b/proxy/src/connection.rs index 2bc5e7a0a..5723e2bbb 100644 --- a/proxy/src/connection.rs +++ b/proxy/src/connection.rs @@ -3,7 +3,7 @@ use std; use std::io; use std::net::SocketAddr; use tokio_core; -use tokio_core::net::TcpListener; +use tokio_core::net::{TcpListener, TcpStreamNew}; use tokio_core::reactor::Handle; use tokio_io::{AsyncRead, AsyncWrite}; @@ -17,7 +17,20 @@ pub struct BoundPort { local_addr: SocketAddr, } +/// Initiates a client connection to the given address. +pub fn connect(addr: &SocketAddr, executor: &Handle) -> Connecting { + Connecting(PlaintextSocket::connect(addr, executor)) +} + +/// A socket that is in the process of connecting. +pub struct Connecting(TcpStreamNew); + /// Abstracts a plaintext socket vs. a TLS decorated one. +/// +/// A `Connection` has the `TCP_NODELAY` option set automatically. Also +/// it strictly controls access to information about the underlying +/// socket to reduce the chance of TLS protections being accidentally +/// subverted. #[derive(Debug)] pub enum Connection { Plain(PlaintextSocket), @@ -60,14 +73,7 @@ impl BoundPort { // doesn't work on all platforms and also the underlying // libraries don't have the necessary API for that, so just // do it here. - if let Err(e) = socket.set_nodelay(true) { - warn!( - "could not set TCP_NODELAY on {:?}/{:?}: {}", - socket.local_addr(), - socket.peer_addr(), - e - ); - } + set_nodelay_or_warn(&socket); f(b, (Connection::Plain(socket), remote_addr)) }); @@ -75,6 +81,19 @@ impl BoundPort { } } +// ===== impl Connecting ===== + +impl Future for Connecting { + type Item = Connection; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + let socket = try_ready!(self.0.poll()); + set_nodelay_or_warn(&socket); + Ok(Async::Ready(Connection::Plain(socket))) + } +} + // ===== impl Connection ===== impl Connection { @@ -140,3 +159,16 @@ impl AsyncWrite for Connection { } } } + +// Misc. + +fn set_nodelay_or_warn(socket: &PlaintextSocket) { + if let Err(e) = socket.set_nodelay(true) { + warn!( + "could not set TCP_NODELAY on {:?}/{:?}: {}", + socket.local_addr(), + socket.peer_addr(), + e + ); + } +} diff --git a/proxy/src/transport/connect.rs b/proxy/src/transport/connect.rs index 82c800f0a..f51068cd6 100644 --- a/proxy/src/transport/connect.rs +++ b/proxy/src/transport/connect.rs @@ -1,18 +1,15 @@ -use futures::{Async, Future, Poll}; +use futures::Future; use tokio_connect; -use tokio_core::net::{TcpStream, TcpStreamNew}; use tokio_core::reactor::Handle; use url; use std::io; use std::net::{IpAddr, SocketAddr}; +use connection; use dns; use ::timeout; -#[must_use = "futures do nothing unless polled"] -pub struct TcpStreamNewNoDelay(TcpStreamNew); - #[derive(Debug, Clone)] pub struct Connect { addr: SocketAddr, @@ -29,26 +26,6 @@ pub struct LookupAddressAndConnect { pub type TimeoutConnect = timeout::Timeout; pub type TimeoutError = timeout::TimeoutError; -// ===== impl TcpStreamNewNoDelay ===== - -impl Future for TcpStreamNewNoDelay { - type Item = TcpStream; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - let tcp = try_ready!(self.0.poll()); - if let Err(e) = tcp.set_nodelay(true) { - warn!( - "could not set TCP_NODELAY on {:?}/{:?}: {}", - tcp.local_addr(), - tcp.peer_addr(), - e - ); - } - Ok(Async::Ready(tcp)) - } -} - // ===== impl Connect ===== impl Connect { @@ -62,13 +39,12 @@ impl Connect { } impl tokio_connect::Connect for Connect { - type Connected = TcpStream; + type Connected = connection::Connection; type Error = io::Error; - type Future = TcpStreamNewNoDelay; + type Future = connection::Connecting; fn connect(&self) -> Self::Future { - trace!("connect {}", self.addr); - TcpStreamNewNoDelay(TcpStream::connect(&self.addr, &self.handle)) + connection::connect(&self.addr, &self.handle) } } @@ -89,9 +65,9 @@ impl LookupAddressAndConnect { } impl tokio_connect::Connect for LookupAddressAndConnect { - type Connected = TcpStream; + type Connected = connection::Connection; type Error = io::Error; - type Future = Box>; + type Future = Box>; fn connect(&self) -> Self::Future { let port = self.host_and_port.port; @@ -106,8 +82,8 @@ impl tokio_connect::Connect for LookupAddressAndConnect { info!("DNS resolved {} to {}", host, ip_addr); let addr = SocketAddr::from((ip_addr, port)); trace!("connect {}", addr); - TcpStreamNewNoDelay(TcpStream::connect(&addr, &handle)) + connection::connect(&addr, &handle) }); Box::new(c) } -} \ No newline at end of file +}