proxy: change peek to use reads for eventual support of TLS (#901)

This commit is contained in:
Sean McArthur 2018-05-08 18:19:12 -07:00 committed by GitHub
parent 50cb2f84db
commit 011d2541eb
4 changed files with 108 additions and 65 deletions

View File

@ -1,6 +1,7 @@
use bytes::Buf; use bytes::{Buf, BytesMut};
use futures::*; use futures::*;
use std; use std;
use std::cmp;
use std::io; use std::io;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio_core::net::{TcpListener, TcpStreamNew, TcpStream}; use tokio_core::net::{TcpListener, TcpStreamNew, TcpStream};
@ -32,19 +33,48 @@ pub struct Connecting(TcpStreamNew);
/// socket to reduce the chance of TLS protections being accidentally /// socket to reduce the chance of TLS protections being accidentally
/// subverted. /// subverted.
#[derive(Debug)] #[derive(Debug)]
pub enum Connection { pub struct Connection {
io: Io,
/// This buffer gets filled up when "peeking" bytes on this Connection.
///
/// This is used instead of MSG_PEEK in order to support TLS streams.
///
/// When calling `read`, it's important to consume bytes from this buffer
/// before calling `io.read`.
peek_buf: BytesMut,
}
#[derive(Debug)]
enum Io {
Plain(PlaintextSocket), Plain(PlaintextSocket),
} }
/// A trait describing that a type can peek (such as MSG_PEEK). /// A trait describing that a type can peek bytes.
pub trait Peek { pub trait Peek {
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize>; /// An async attempt to peek bytes of this type without consuming.
///
/// Returns number of bytes that have been peeked.
fn poll_peek(&mut self) -> Poll<usize, io::Error>;
/// Returns a reference to the bytes that have been peeked.
// Instead of passing a buffer into `peek()`, the bytes are kept in
// a buffer owned by the `Peek` type. This allows looking at the
// peeked bytes cheaply, instead of needing to copy into a new
// buffer.
fn peeked(&self) -> &[u8];
/// A `Future` around `poll_peek`, returning this type instead.
fn peek(self) -> PeekFuture<Self> where Self: Sized {
PeekFuture {
inner: Some(self),
}
}
} }
/// A future of when some `Peek` fulfills with some bytes. /// A future of when some `Peek` fulfills with some bytes.
#[derive(Debug)] #[derive(Debug)]
pub struct PeekFuture<T, B> { pub struct PeekFuture<T> {
inner: Option<(T, B)>, inner: Option<T>,
} }
// ===== impl BoundPort ===== // ===== impl BoundPort =====
@ -85,7 +115,7 @@ impl BoundPort {
// libraries don't have the necessary API for that, so just // libraries don't have the necessary API for that, so just
// do it here. // do it here.
set_nodelay_or_warn(&socket); set_nodelay_or_warn(&socket);
f(b, (Connection::Plain(socket), remote_addr)) f(b, (Connection::plain(socket), remote_addr))
}); });
Box::new(fut.map(|_| ())) Box::new(fut.map(|_| ()))
@ -101,13 +131,21 @@ impl Future for Connecting {
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let socket = try_ready!(self.0.poll()); let socket = try_ready!(self.0.poll());
set_nodelay_or_warn(&socket); set_nodelay_or_warn(&socket);
Ok(Async::Ready(Connection::Plain(socket))) Ok(Async::Ready(Connection::plain(socket)))
} }
} }
// ===== impl Connection ===== // ===== impl Connection =====
impl Connection { impl Connection {
/// A constructor of `Connection` with a plain text TCP socket.
pub fn plain(socket: PlaintextSocket) -> Self {
Connection {
io: Io::Plain(socket),
peek_buf: BytesMut::new(),
}
}
pub fn original_dst_addr<T: GetOriginalDst>(&self, get: &T) -> Option<SocketAddr> { pub fn original_dst_addr<T: GetOriginalDst>(&self, get: &T) -> Option<SocketAddr> {
get.get_original_dst(self.socket()) get.get_original_dst(self.socket())
} }
@ -123,46 +161,55 @@ impl Connection {
// underlying socket should be exposed by its own minimal accessor function // underlying socket should be exposed by its own minimal accessor function
// as is done above. // as is done above.
fn socket(&self) -> &PlaintextSocket { fn socket(&self) -> &PlaintextSocket {
match self { match self.io {
&Connection::Plain(ref socket) => socket Io::Plain(ref socket) => socket
} }
} }
} }
impl io::Read for Connection { impl io::Read for Connection {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
use self::Connection::*; // Check the length only once, since looking as the length
// of a BytesMut isn't as cheap as the length of a &[u8].
let peeked_len = self.peek_buf.len();
match *self { if peeked_len == 0 {
Plain(ref mut t) => t.read(buf), match self.io {
Io::Plain(ref mut t) => t.read(buf),
}
} else {
let len = cmp::min(buf.len(), peeked_len);
buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]);
self.peek_buf.advance(len);
// If we've finally emptied the peek_buf, drop it so we don't
// hold onto the allocated memory any longer. We won't peek
// again.
if peeked_len == len {
self.peek_buf = BytesMut::new();
}
Ok(len)
} }
} }
} }
impl AsyncRead for Connection { impl AsyncRead for Connection {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
use self::Connection::*; match self.io {
Io::Plain(ref t) => t.prepare_uninitialized_buffer(buf),
match *self {
Plain(ref t) => t.prepare_uninitialized_buffer(buf),
} }
} }
} }
impl io::Write for Connection { impl io::Write for Connection {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
use self::Connection::*; match self.io {
Io::Plain(ref mut t) => t.write(buf),
match *self {
Plain(ref mut t) => t.write(buf),
} }
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
use self::Connection::*; match self.io {
Io::Plain(ref mut t) => t.flush(),
match *self {
Plain(ref mut t) => t.flush(),
} }
} }
} }
@ -170,10 +217,8 @@ impl io::Write for Connection {
impl AsyncWrite for Connection { impl AsyncWrite for Connection {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn shutdown(&mut self) -> Poll<(), io::Error> {
use std::net::Shutdown; use std::net::Shutdown;
use self::Connection::*; match self.io {
Io::Plain(ref mut t) => {
match *self {
Plain(ref mut t) => {
try_ready!(AsyncWrite::shutdown(t)); try_ready!(AsyncWrite::shutdown(t));
// TCP shutdown the write side. // TCP shutdown the write side.
// //
@ -187,49 +232,44 @@ impl AsyncWrite for Connection {
} }
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> { fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
use self::Connection::*; match self.io {
Io::Plain(ref mut t) => t.write_buf(buf),
match *self {
Plain(ref mut t) => t.write_buf(buf),
} }
} }
} }
impl Peek for Connection { impl Peek for Connection {
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn poll_peek(&mut self) -> Poll<usize, io::Error> {
use self::Connection::*; if self.peek_buf.is_empty() {
self.peek_buf.reserve(8192);
match *self { match self.io {
Plain(ref mut t) => t.peek(buf), Io::Plain(ref mut t) => t.read_buf(&mut self.peek_buf),
}
} else {
Ok(Async::Ready(self.peek_buf.len()))
} }
} }
fn peeked(&self) -> &[u8] {
self.peek_buf.as_ref()
}
} }
// impl PeekFuture // impl PeekFuture
impl<T: Peek, B: AsMut<[u8]>> PeekFuture<T, B> { impl<T: Peek> Future for PeekFuture<T> {
pub fn new(io: T, buf: B) -> Self { type Item = T;
PeekFuture {
inner: Some((io, buf)),
}
}
}
impl<T: Peek, B: AsMut<[u8]>> Future for PeekFuture<T, B> {
type Item = (T, B, usize);
type Error = std::io::Error; type Error = std::io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let (mut io, mut buf) = self.inner.take().expect("polled after completed"); let mut io = self.inner.take().expect("polled after completed");
match io.peek(buf.as_mut()) { match io.poll_peek() {
Ok(n) => Ok(Async::Ready((io, buf, n))), Ok(Async::Ready(_)) => Ok(Async::Ready(io)),
Err(e) => match e.kind() { Ok(Async::NotReady) => {
std::io::ErrorKind::WouldBlock => { self.inner = Some(io);
self.inner = Some((io, buf)); Ok(Async::NotReady)
Ok(Async::NotReady)
},
_ => Err(e)
}, },
Err(e) => Err(e),
} }
} }
} }

View File

@ -180,8 +180,12 @@ impl<T: AsyncRead + AsyncWrite> AsyncWrite for Transport<T> {
} }
impl<T: AsyncRead + AsyncWrite + Peek> Peek for Transport<T> { impl<T: AsyncRead + AsyncWrite + Peek> Peek for Transport<T> {
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn poll_peek(&mut self) -> Poll<usize, io::Error> {
self.sense_err(|io| io.peek(buf)) self.sense_err(|io| io.poll_peek())
}
fn peeked(&self) -> &[u8] {
self.0.peeked()
} }
} }

View File

@ -12,7 +12,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use tower_service::NewService; use tower_service::NewService;
use tower_h2; use tower_h2;
use connection::{Connection, PeekFuture}; use connection::{Connection, Peek};
use ctx::Proxy as ProxyCtx; use ctx::Proxy as ProxyCtx;
use ctx::transport::{Server as ServerCtx}; use ctx::transport::{Server as ServerCtx};
use drain; use drain;
@ -131,16 +131,15 @@ where
} }
// try to sniff protocol // try to sniff protocol
let sniff = [0u8; 32];
let h1 = self.h1.clone(); let h1 = self.h1.clone();
let h2 = self.h2.clone(); let h2 = self.h2.clone();
let tcp = self.tcp.clone(); let tcp = self.tcp.clone();
let new_service = self.new_service.clone(); let new_service = self.new_service.clone();
let drain_signal = self.drain_signal.clone(); let drain_signal = self.drain_signal.clone();
let fut = PeekFuture::new(io, sniff) let fut = io.peek()
.map_err(|e| debug!("peek error: {}", e)) .map_err(|e| debug!("peek error: {}", e))
.and_then(move |(io, sniff, n)| -> Box<Future<Item=(), Error=()>> { .and_then(move |io| -> Box<Future<Item=(), Error=()>> {
if let Some(proto) = Protocol::detect(&sniff[..n]) { if let Some(proto) = Protocol::detect(io.peeked()) {
match proto { match proto {
Protocol::Http1 => { Protocol::Http1 => {
trace!("transparency detected HTTP/1"); trace!("transparency detected HTTP/1");

View File

@ -69,7 +69,7 @@ impl Proxy {
let connect = self.sensors.connect(c, &client_ctx); let connect = self.sensors.connect(c, &client_ctx);
let fut = connect.connect() let fut = connect.connect()
.map_err(|e| debug!("tcp connect error: {:?}", e)) .map_err(move |e| error!("tcp connect error to {}: {:?}", orig_dst, e))
.and_then(move |tcp_out| { .and_then(move |tcp_out| {
Duplex::new(tcp_in, tcp_out) Duplex::new(tcp_in, tcp_out)
.map_err(|e| error!("tcp duplex error: {}", e)) .map_err(|e| error!("tcp duplex error: {}", e))