diff --git a/proxy/src/connection.rs b/proxy/src/connection.rs index c35665ab1..7070d9cbf 100644 --- a/proxy/src/connection.rs +++ b/proxy/src/connection.rs @@ -1,6 +1,7 @@ -use bytes::Buf; +use bytes::{Buf, BytesMut}; use futures::*; use std; +use std::cmp; use std::io; use std::net::SocketAddr; use tokio_core::net::{TcpListener, TcpStreamNew, TcpStream}; @@ -32,19 +33,48 @@ pub struct Connecting(TcpStreamNew); /// socket to reduce the chance of TLS protections being accidentally /// subverted. #[derive(Debug)] -pub enum Connection { +pub struct Connection { + io: Io, + /// This buffer gets filled up when "peeking" bytes on this Connection. + /// + /// This is used instead of MSG_PEEK in order to support TLS streams. + /// + /// When calling `read`, it's important to consume bytes from this buffer + /// before calling `io.read`. + peek_buf: BytesMut, +} + +#[derive(Debug)] +enum Io { Plain(PlaintextSocket), } -/// A trait describing that a type can peek (such as MSG_PEEK). +/// A trait describing that a type can peek bytes. pub trait Peek { - fn peek(&mut self, buf: &mut [u8]) -> io::Result; + /// An async attempt to peek bytes of this type without consuming. + /// + /// Returns number of bytes that have been peeked. + fn poll_peek(&mut self) -> Poll; + + /// Returns a reference to the bytes that have been peeked. + // Instead of passing a buffer into `peek()`, the bytes are kept in + // a buffer owned by the `Peek` type. This allows looking at the + // peeked bytes cheaply, instead of needing to copy into a new + // buffer. + fn peeked(&self) -> &[u8]; + + /// A `Future` around `poll_peek`, returning this type instead. + fn peek(self) -> PeekFuture where Self: Sized { + PeekFuture { + inner: Some(self), + } + } } /// A future of when some `Peek` fulfills with some bytes. #[derive(Debug)] -pub struct PeekFuture { - inner: Option<(T, B)>, +pub struct PeekFuture { + inner: Option, } // ===== impl BoundPort ===== @@ -85,7 +115,7 @@ impl BoundPort { // libraries don't have the necessary API for that, so just // do it here. set_nodelay_or_warn(&socket); - f(b, (Connection::Plain(socket), remote_addr)) + f(b, (Connection::plain(socket), remote_addr)) }); Box::new(fut.map(|_| ())) @@ -101,13 +131,21 @@ impl Future for Connecting { fn poll(&mut self) -> Poll { let socket = try_ready!(self.0.poll()); set_nodelay_or_warn(&socket); - Ok(Async::Ready(Connection::Plain(socket))) + Ok(Async::Ready(Connection::plain(socket))) } } // ===== impl Connection ===== impl Connection { + /// A constructor of `Connection` with a plain text TCP socket. + pub fn plain(socket: PlaintextSocket) -> Self { + Connection { + io: Io::Plain(socket), + peek_buf: BytesMut::new(), + } + } + pub fn original_dst_addr(&self, get: &T) -> Option { get.get_original_dst(self.socket()) } @@ -123,46 +161,55 @@ impl Connection { // underlying socket should be exposed by its own minimal accessor function // as is done above. fn socket(&self) -> &PlaintextSocket { - match self { - &Connection::Plain(ref socket) => socket + match self.io { + Io::Plain(ref socket) => socket } } } impl io::Read for Connection { fn read(&mut self, buf: &mut [u8]) -> io::Result { - use self::Connection::*; + // Check the length only once, since looking as the length + // of a BytesMut isn't as cheap as the length of a &[u8]. + let peeked_len = self.peek_buf.len(); - match *self { - Plain(ref mut t) => t.read(buf), + if peeked_len == 0 { + match self.io { + Io::Plain(ref mut t) => t.read(buf), + } + } else { + let len = cmp::min(buf.len(), peeked_len); + buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]); + self.peek_buf.advance(len); + // If we've finally emptied the peek_buf, drop it so we don't + // hold onto the allocated memory any longer. We won't peek + // again. + if peeked_len == len { + self.peek_buf = BytesMut::new(); + } + Ok(len) } } } impl AsyncRead for Connection { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - use self::Connection::*; - - match *self { - Plain(ref t) => t.prepare_uninitialized_buffer(buf), + match self.io { + Io::Plain(ref t) => t.prepare_uninitialized_buffer(buf), } } } impl io::Write for Connection { fn write(&mut self, buf: &[u8]) -> io::Result { - use self::Connection::*; - - match *self { - Plain(ref mut t) => t.write(buf), + match self.io { + Io::Plain(ref mut t) => t.write(buf), } } fn flush(&mut self) -> io::Result<()> { - use self::Connection::*; - - match *self { - Plain(ref mut t) => t.flush(), + match self.io { + Io::Plain(ref mut t) => t.flush(), } } } @@ -170,10 +217,8 @@ impl io::Write for Connection { impl AsyncWrite for Connection { fn shutdown(&mut self) -> Poll<(), io::Error> { use std::net::Shutdown; - use self::Connection::*; - - match *self { - Plain(ref mut t) => { + match self.io { + Io::Plain(ref mut t) => { try_ready!(AsyncWrite::shutdown(t)); // TCP shutdown the write side. // @@ -187,49 +232,44 @@ impl AsyncWrite for Connection { } fn write_buf(&mut self, buf: &mut B) -> Poll { - use self::Connection::*; - - match *self { - Plain(ref mut t) => t.write_buf(buf), + match self.io { + Io::Plain(ref mut t) => t.write_buf(buf), } } } impl Peek for Connection { - fn peek(&mut self, buf: &mut [u8]) -> io::Result { - use self::Connection::*; - - match *self { - Plain(ref mut t) => t.peek(buf), + fn poll_peek(&mut self) -> Poll { + if self.peek_buf.is_empty() { + self.peek_buf.reserve(8192); + match self.io { + Io::Plain(ref mut t) => t.read_buf(&mut self.peek_buf), + } + } else { + Ok(Async::Ready(self.peek_buf.len())) } } + + fn peeked(&self) -> &[u8] { + self.peek_buf.as_ref() + } } // impl PeekFuture -impl> PeekFuture { - pub fn new(io: T, buf: B) -> Self { - PeekFuture { - inner: Some((io, buf)), - } - } -} - -impl> Future for PeekFuture { - type Item = (T, B, usize); +impl Future for PeekFuture { + type Item = T; type Error = std::io::Error; fn poll(&mut self) -> Poll { - let (mut io, mut buf) = self.inner.take().expect("polled after completed"); - match io.peek(buf.as_mut()) { - Ok(n) => Ok(Async::Ready((io, buf, n))), - Err(e) => match e.kind() { - std::io::ErrorKind::WouldBlock => { - self.inner = Some((io, buf)); - Ok(Async::NotReady) - }, - _ => Err(e) + let mut io = self.inner.take().expect("polled after completed"); + match io.poll_peek() { + Ok(Async::Ready(_)) => Ok(Async::Ready(io)), + Ok(Async::NotReady) => { + self.inner = Some(io); + Ok(Async::NotReady) }, + Err(e) => Err(e), } } } diff --git a/proxy/src/telemetry/sensor/transport.rs b/proxy/src/telemetry/sensor/transport.rs index 3f56ef4c4..dbab4fcfc 100644 --- a/proxy/src/telemetry/sensor/transport.rs +++ b/proxy/src/telemetry/sensor/transport.rs @@ -180,8 +180,12 @@ impl AsyncWrite for Transport { } impl Peek for Transport { - fn peek(&mut self, buf: &mut [u8]) -> io::Result { - self.sense_err(|io| io.peek(buf)) + fn poll_peek(&mut self) -> Poll { + self.sense_err(|io| io.poll_peek()) + } + + fn peeked(&self) -> &[u8] { + self.0.peeked() } } diff --git a/proxy/src/transparency/server.rs b/proxy/src/transparency/server.rs index 38ee72b24..23a3d8f62 100644 --- a/proxy/src/transparency/server.rs +++ b/proxy/src/transparency/server.rs @@ -12,7 +12,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tower_service::NewService; use tower_h2; -use connection::{Connection, PeekFuture}; +use connection::{Connection, Peek}; use ctx::Proxy as ProxyCtx; use ctx::transport::{Server as ServerCtx}; use drain; @@ -131,16 +131,15 @@ where } // try to sniff protocol - let sniff = [0u8; 32]; let h1 = self.h1.clone(); let h2 = self.h2.clone(); let tcp = self.tcp.clone(); let new_service = self.new_service.clone(); let drain_signal = self.drain_signal.clone(); - let fut = PeekFuture::new(io, sniff) + let fut = io.peek() .map_err(|e| debug!("peek error: {}", e)) - .and_then(move |(io, sniff, n)| -> Box> { - if let Some(proto) = Protocol::detect(&sniff[..n]) { + .and_then(move |io| -> Box> { + if let Some(proto) = Protocol::detect(io.peeked()) { match proto { Protocol::Http1 => { trace!("transparency detected HTTP/1"); diff --git a/proxy/src/transparency/tcp.rs b/proxy/src/transparency/tcp.rs index 4c3e96398..32ac3d063 100644 --- a/proxy/src/transparency/tcp.rs +++ b/proxy/src/transparency/tcp.rs @@ -69,7 +69,7 @@ impl Proxy { let connect = self.sensors.connect(c, &client_ctx); let fut = connect.connect() - .map_err(|e| debug!("tcp connect error: {:?}", e)) + .map_err(move |e| error!("tcp connect error to {}: {:?}", orig_dst, e)) .and_then(move |tcp_out| { Duplex::new(tcp_in, tcp_out) .map_err(|e| error!("tcp duplex error: {}", e))