proxy: change peek to use reads for eventual support of TLS (#901)

This commit is contained in:
Sean McArthur 2018-05-08 18:19:12 -07:00 committed by GitHub
parent 50cb2f84db
commit 011d2541eb
4 changed files with 108 additions and 65 deletions

View File

@ -1,6 +1,7 @@
use bytes::Buf;
use bytes::{Buf, BytesMut};
use futures::*;
use std;
use std::cmp;
use std::io;
use std::net::SocketAddr;
use tokio_core::net::{TcpListener, TcpStreamNew, TcpStream};
@ -32,19 +33,48 @@ pub struct Connecting(TcpStreamNew);
/// socket to reduce the chance of TLS protections being accidentally
/// subverted.
#[derive(Debug)]
pub enum Connection {
pub struct Connection {
io: Io,
/// This buffer gets filled up when "peeking" bytes on this Connection.
///
/// This is used instead of MSG_PEEK in order to support TLS streams.
///
/// When calling `read`, it's important to consume bytes from this buffer
/// before calling `io.read`.
peek_buf: BytesMut,
}
#[derive(Debug)]
enum Io {
Plain(PlaintextSocket),
}
/// A trait describing that a type can peek (such as MSG_PEEK).
/// A trait describing that a type can peek bytes.
pub trait Peek {
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize>;
/// An async attempt to peek bytes of this type without consuming.
///
/// Returns number of bytes that have been peeked.
fn poll_peek(&mut self) -> Poll<usize, io::Error>;
/// Returns a reference to the bytes that have been peeked.
// Instead of passing a buffer into `peek()`, the bytes are kept in
// a buffer owned by the `Peek` type. This allows looking at the
// peeked bytes cheaply, instead of needing to copy into a new
// buffer.
fn peeked(&self) -> &[u8];
/// A `Future` around `poll_peek`, returning this type instead.
fn peek(self) -> PeekFuture<Self> where Self: Sized {
PeekFuture {
inner: Some(self),
}
}
}
/// A future of when some `Peek` fulfills with some bytes.
#[derive(Debug)]
pub struct PeekFuture<T, B> {
inner: Option<(T, B)>,
pub struct PeekFuture<T> {
inner: Option<T>,
}
// ===== impl BoundPort =====
@ -85,7 +115,7 @@ impl BoundPort {
// libraries don't have the necessary API for that, so just
// do it here.
set_nodelay_or_warn(&socket);
f(b, (Connection::Plain(socket), remote_addr))
f(b, (Connection::plain(socket), remote_addr))
});
Box::new(fut.map(|_| ()))
@ -101,13 +131,21 @@ impl Future for Connecting {
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let socket = try_ready!(self.0.poll());
set_nodelay_or_warn(&socket);
Ok(Async::Ready(Connection::Plain(socket)))
Ok(Async::Ready(Connection::plain(socket)))
}
}
// ===== impl Connection =====
impl Connection {
/// A constructor of `Connection` with a plain text TCP socket.
pub fn plain(socket: PlaintextSocket) -> Self {
Connection {
io: Io::Plain(socket),
peek_buf: BytesMut::new(),
}
}
pub fn original_dst_addr<T: GetOriginalDst>(&self, get: &T) -> Option<SocketAddr> {
get.get_original_dst(self.socket())
}
@ -123,46 +161,55 @@ impl Connection {
// underlying socket should be exposed by its own minimal accessor function
// as is done above.
fn socket(&self) -> &PlaintextSocket {
match self {
&Connection::Plain(ref socket) => socket
match self.io {
Io::Plain(ref socket) => socket
}
}
}
impl io::Read for Connection {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
use self::Connection::*;
// Check the length only once, since looking as the length
// of a BytesMut isn't as cheap as the length of a &[u8].
let peeked_len = self.peek_buf.len();
match *self {
Plain(ref mut t) => t.read(buf),
if peeked_len == 0 {
match self.io {
Io::Plain(ref mut t) => t.read(buf),
}
} else {
let len = cmp::min(buf.len(), peeked_len);
buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]);
self.peek_buf.advance(len);
// If we've finally emptied the peek_buf, drop it so we don't
// hold onto the allocated memory any longer. We won't peek
// again.
if peeked_len == len {
self.peek_buf = BytesMut::new();
}
Ok(len)
}
}
}
impl AsyncRead for Connection {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
use self::Connection::*;
match *self {
Plain(ref t) => t.prepare_uninitialized_buffer(buf),
match self.io {
Io::Plain(ref t) => t.prepare_uninitialized_buffer(buf),
}
}
}
impl io::Write for Connection {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
use self::Connection::*;
match *self {
Plain(ref mut t) => t.write(buf),
match self.io {
Io::Plain(ref mut t) => t.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
use self::Connection::*;
match *self {
Plain(ref mut t) => t.flush(),
match self.io {
Io::Plain(ref mut t) => t.flush(),
}
}
}
@ -170,10 +217,8 @@ impl io::Write for Connection {
impl AsyncWrite for Connection {
fn shutdown(&mut self) -> Poll<(), io::Error> {
use std::net::Shutdown;
use self::Connection::*;
match *self {
Plain(ref mut t) => {
match self.io {
Io::Plain(ref mut t) => {
try_ready!(AsyncWrite::shutdown(t));
// TCP shutdown the write side.
//
@ -187,49 +232,44 @@ impl AsyncWrite for Connection {
}
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
use self::Connection::*;
match *self {
Plain(ref mut t) => t.write_buf(buf),
match self.io {
Io::Plain(ref mut t) => t.write_buf(buf),
}
}
}
impl Peek for Connection {
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
use self::Connection::*;
match *self {
Plain(ref mut t) => t.peek(buf),
fn poll_peek(&mut self) -> Poll<usize, io::Error> {
if self.peek_buf.is_empty() {
self.peek_buf.reserve(8192);
match self.io {
Io::Plain(ref mut t) => t.read_buf(&mut self.peek_buf),
}
} else {
Ok(Async::Ready(self.peek_buf.len()))
}
}
fn peeked(&self) -> &[u8] {
self.peek_buf.as_ref()
}
}
// impl PeekFuture
impl<T: Peek, B: AsMut<[u8]>> PeekFuture<T, B> {
pub fn new(io: T, buf: B) -> Self {
PeekFuture {
inner: Some((io, buf)),
}
}
}
impl<T: Peek, B: AsMut<[u8]>> Future for PeekFuture<T, B> {
type Item = (T, B, usize);
impl<T: Peek> Future for PeekFuture<T> {
type Item = T;
type Error = std::io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let (mut io, mut buf) = self.inner.take().expect("polled after completed");
match io.peek(buf.as_mut()) {
Ok(n) => Ok(Async::Ready((io, buf, n))),
Err(e) => match e.kind() {
std::io::ErrorKind::WouldBlock => {
self.inner = Some((io, buf));
Ok(Async::NotReady)
},
_ => Err(e)
let mut io = self.inner.take().expect("polled after completed");
match io.poll_peek() {
Ok(Async::Ready(_)) => Ok(Async::Ready(io)),
Ok(Async::NotReady) => {
self.inner = Some(io);
Ok(Async::NotReady)
},
Err(e) => Err(e),
}
}
}

View File

@ -180,8 +180,12 @@ impl<T: AsyncRead + AsyncWrite> AsyncWrite for Transport<T> {
}
impl<T: AsyncRead + AsyncWrite + Peek> Peek for Transport<T> {
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.sense_err(|io| io.peek(buf))
fn poll_peek(&mut self) -> Poll<usize, io::Error> {
self.sense_err(|io| io.poll_peek())
}
fn peeked(&self) -> &[u8] {
self.0.peeked()
}
}

View File

@ -12,7 +12,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use tower_service::NewService;
use tower_h2;
use connection::{Connection, PeekFuture};
use connection::{Connection, Peek};
use ctx::Proxy as ProxyCtx;
use ctx::transport::{Server as ServerCtx};
use drain;
@ -131,16 +131,15 @@ where
}
// try to sniff protocol
let sniff = [0u8; 32];
let h1 = self.h1.clone();
let h2 = self.h2.clone();
let tcp = self.tcp.clone();
let new_service = self.new_service.clone();
let drain_signal = self.drain_signal.clone();
let fut = PeekFuture::new(io, sniff)
let fut = io.peek()
.map_err(|e| debug!("peek error: {}", e))
.and_then(move |(io, sniff, n)| -> Box<Future<Item=(), Error=()>> {
if let Some(proto) = Protocol::detect(&sniff[..n]) {
.and_then(move |io| -> Box<Future<Item=(), Error=()>> {
if let Some(proto) = Protocol::detect(io.peeked()) {
match proto {
Protocol::Http1 => {
trace!("transparency detected HTTP/1");

View File

@ -69,7 +69,7 @@ impl Proxy {
let connect = self.sensors.connect(c, &client_ctx);
let fut = connect.connect()
.map_err(|e| debug!("tcp connect error: {:?}", e))
.map_err(move |e| error!("tcp connect error to {}: {:?}", orig_dst, e))
.and_then(move |tcp_out| {
Duplex::new(tcp_in, tcp_out)
.map_err(|e| error!("tcp duplex error: {}", e))