proxy: change peek to use reads for eventual support of TLS (#901)
This commit is contained in:
parent
50cb2f84db
commit
011d2541eb
|
@ -1,6 +1,7 @@
|
|||
use bytes::Buf;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use futures::*;
|
||||
use std;
|
||||
use std::cmp;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
use tokio_core::net::{TcpListener, TcpStreamNew, TcpStream};
|
||||
|
@ -32,19 +33,48 @@ pub struct Connecting(TcpStreamNew);
|
|||
/// socket to reduce the chance of TLS protections being accidentally
|
||||
/// subverted.
|
||||
#[derive(Debug)]
|
||||
pub enum Connection {
|
||||
pub struct Connection {
|
||||
io: Io,
|
||||
/// This buffer gets filled up when "peeking" bytes on this Connection.
|
||||
///
|
||||
/// This is used instead of MSG_PEEK in order to support TLS streams.
|
||||
///
|
||||
/// When calling `read`, it's important to consume bytes from this buffer
|
||||
/// before calling `io.read`.
|
||||
peek_buf: BytesMut,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Io {
|
||||
Plain(PlaintextSocket),
|
||||
}
|
||||
|
||||
/// A trait describing that a type can peek (such as MSG_PEEK).
|
||||
/// A trait describing that a type can peek bytes.
|
||||
pub trait Peek {
|
||||
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize>;
|
||||
/// An async attempt to peek bytes of this type without consuming.
|
||||
///
|
||||
/// Returns number of bytes that have been peeked.
|
||||
fn poll_peek(&mut self) -> Poll<usize, io::Error>;
|
||||
|
||||
/// Returns a reference to the bytes that have been peeked.
|
||||
// Instead of passing a buffer into `peek()`, the bytes are kept in
|
||||
// a buffer owned by the `Peek` type. This allows looking at the
|
||||
// peeked bytes cheaply, instead of needing to copy into a new
|
||||
// buffer.
|
||||
fn peeked(&self) -> &[u8];
|
||||
|
||||
/// A `Future` around `poll_peek`, returning this type instead.
|
||||
fn peek(self) -> PeekFuture<Self> where Self: Sized {
|
||||
PeekFuture {
|
||||
inner: Some(self),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A future of when some `Peek` fulfills with some bytes.
|
||||
#[derive(Debug)]
|
||||
pub struct PeekFuture<T, B> {
|
||||
inner: Option<(T, B)>,
|
||||
pub struct PeekFuture<T> {
|
||||
inner: Option<T>,
|
||||
}
|
||||
|
||||
// ===== impl BoundPort =====
|
||||
|
@ -85,7 +115,7 @@ impl BoundPort {
|
|||
// libraries don't have the necessary API for that, so just
|
||||
// do it here.
|
||||
set_nodelay_or_warn(&socket);
|
||||
f(b, (Connection::Plain(socket), remote_addr))
|
||||
f(b, (Connection::plain(socket), remote_addr))
|
||||
});
|
||||
|
||||
Box::new(fut.map(|_| ()))
|
||||
|
@ -101,13 +131,21 @@ impl Future for Connecting {
|
|||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
let socket = try_ready!(self.0.poll());
|
||||
set_nodelay_or_warn(&socket);
|
||||
Ok(Async::Ready(Connection::Plain(socket)))
|
||||
Ok(Async::Ready(Connection::plain(socket)))
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl Connection =====
|
||||
|
||||
impl Connection {
|
||||
/// A constructor of `Connection` with a plain text TCP socket.
|
||||
pub fn plain(socket: PlaintextSocket) -> Self {
|
||||
Connection {
|
||||
io: Io::Plain(socket),
|
||||
peek_buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn original_dst_addr<T: GetOriginalDst>(&self, get: &T) -> Option<SocketAddr> {
|
||||
get.get_original_dst(self.socket())
|
||||
}
|
||||
|
@ -123,46 +161,55 @@ impl Connection {
|
|||
// underlying socket should be exposed by its own minimal accessor function
|
||||
// as is done above.
|
||||
fn socket(&self) -> &PlaintextSocket {
|
||||
match self {
|
||||
&Connection::Plain(ref socket) => socket
|
||||
match self.io {
|
||||
Io::Plain(ref socket) => socket
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Read for Connection {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
use self::Connection::*;
|
||||
// Check the length only once, since looking as the length
|
||||
// of a BytesMut isn't as cheap as the length of a &[u8].
|
||||
let peeked_len = self.peek_buf.len();
|
||||
|
||||
match *self {
|
||||
Plain(ref mut t) => t.read(buf),
|
||||
if peeked_len == 0 {
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.read(buf),
|
||||
}
|
||||
} else {
|
||||
let len = cmp::min(buf.len(), peeked_len);
|
||||
buf[..len].copy_from_slice(&self.peek_buf.as_ref()[..len]);
|
||||
self.peek_buf.advance(len);
|
||||
// If we've finally emptied the peek_buf, drop it so we don't
|
||||
// hold onto the allocated memory any longer. We won't peek
|
||||
// again.
|
||||
if peeked_len == len {
|
||||
self.peek_buf = BytesMut::new();
|
||||
}
|
||||
Ok(len)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Connection {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
use self::Connection::*;
|
||||
|
||||
match *self {
|
||||
Plain(ref t) => t.prepare_uninitialized_buffer(buf),
|
||||
match self.io {
|
||||
Io::Plain(ref t) => t.prepare_uninitialized_buffer(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for Connection {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
use self::Connection::*;
|
||||
|
||||
match *self {
|
||||
Plain(ref mut t) => t.write(buf),
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
use self::Connection::*;
|
||||
|
||||
match *self {
|
||||
Plain(ref mut t) => t.flush(),
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -170,10 +217,8 @@ impl io::Write for Connection {
|
|||
impl AsyncWrite for Connection {
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
use std::net::Shutdown;
|
||||
use self::Connection::*;
|
||||
|
||||
match *self {
|
||||
Plain(ref mut t) => {
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => {
|
||||
try_ready!(AsyncWrite::shutdown(t));
|
||||
// TCP shutdown the write side.
|
||||
//
|
||||
|
@ -187,49 +232,44 @@ impl AsyncWrite for Connection {
|
|||
}
|
||||
|
||||
fn write_buf<B: Buf>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
|
||||
use self::Connection::*;
|
||||
|
||||
match *self {
|
||||
Plain(ref mut t) => t.write_buf(buf),
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.write_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Peek for Connection {
|
||||
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
use self::Connection::*;
|
||||
|
||||
match *self {
|
||||
Plain(ref mut t) => t.peek(buf),
|
||||
fn poll_peek(&mut self) -> Poll<usize, io::Error> {
|
||||
if self.peek_buf.is_empty() {
|
||||
self.peek_buf.reserve(8192);
|
||||
match self.io {
|
||||
Io::Plain(ref mut t) => t.read_buf(&mut self.peek_buf),
|
||||
}
|
||||
} else {
|
||||
Ok(Async::Ready(self.peek_buf.len()))
|
||||
}
|
||||
}
|
||||
|
||||
fn peeked(&self) -> &[u8] {
|
||||
self.peek_buf.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
// impl PeekFuture
|
||||
|
||||
impl<T: Peek, B: AsMut<[u8]>> PeekFuture<T, B> {
|
||||
pub fn new(io: T, buf: B) -> Self {
|
||||
PeekFuture {
|
||||
inner: Some((io, buf)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Peek, B: AsMut<[u8]>> Future for PeekFuture<T, B> {
|
||||
type Item = (T, B, usize);
|
||||
impl<T: Peek> Future for PeekFuture<T> {
|
||||
type Item = T;
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
let (mut io, mut buf) = self.inner.take().expect("polled after completed");
|
||||
match io.peek(buf.as_mut()) {
|
||||
Ok(n) => Ok(Async::Ready((io, buf, n))),
|
||||
Err(e) => match e.kind() {
|
||||
std::io::ErrorKind::WouldBlock => {
|
||||
self.inner = Some((io, buf));
|
||||
Ok(Async::NotReady)
|
||||
},
|
||||
_ => Err(e)
|
||||
let mut io = self.inner.take().expect("polled after completed");
|
||||
match io.poll_peek() {
|
||||
Ok(Async::Ready(_)) => Ok(Async::Ready(io)),
|
||||
Ok(Async::NotReady) => {
|
||||
self.inner = Some(io);
|
||||
Ok(Async::NotReady)
|
||||
},
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -180,8 +180,12 @@ impl<T: AsyncRead + AsyncWrite> AsyncWrite for Transport<T> {
|
|||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Peek> Peek for Transport<T> {
|
||||
fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.sense_err(|io| io.peek(buf))
|
||||
fn poll_peek(&mut self) -> Poll<usize, io::Error> {
|
||||
self.sense_err(|io| io.poll_peek())
|
||||
}
|
||||
|
||||
fn peeked(&self) -> &[u8] {
|
||||
self.0.peeked()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
|
|||
use tower_service::NewService;
|
||||
use tower_h2;
|
||||
|
||||
use connection::{Connection, PeekFuture};
|
||||
use connection::{Connection, Peek};
|
||||
use ctx::Proxy as ProxyCtx;
|
||||
use ctx::transport::{Server as ServerCtx};
|
||||
use drain;
|
||||
|
@ -131,16 +131,15 @@ where
|
|||
}
|
||||
|
||||
// try to sniff protocol
|
||||
let sniff = [0u8; 32];
|
||||
let h1 = self.h1.clone();
|
||||
let h2 = self.h2.clone();
|
||||
let tcp = self.tcp.clone();
|
||||
let new_service = self.new_service.clone();
|
||||
let drain_signal = self.drain_signal.clone();
|
||||
let fut = PeekFuture::new(io, sniff)
|
||||
let fut = io.peek()
|
||||
.map_err(|e| debug!("peek error: {}", e))
|
||||
.and_then(move |(io, sniff, n)| -> Box<Future<Item=(), Error=()>> {
|
||||
if let Some(proto) = Protocol::detect(&sniff[..n]) {
|
||||
.and_then(move |io| -> Box<Future<Item=(), Error=()>> {
|
||||
if let Some(proto) = Protocol::detect(io.peeked()) {
|
||||
match proto {
|
||||
Protocol::Http1 => {
|
||||
trace!("transparency detected HTTP/1");
|
||||
|
|
|
@ -69,7 +69,7 @@ impl Proxy {
|
|||
let connect = self.sensors.connect(c, &client_ctx);
|
||||
|
||||
let fut = connect.connect()
|
||||
.map_err(|e| debug!("tcp connect error: {:?}", e))
|
||||
.map_err(move |e| error!("tcp connect error to {}: {:?}", orig_dst, e))
|
||||
.and_then(move |tcp_out| {
|
||||
Duplex::new(tcp_in, tcp_out)
|
||||
.map_err(|e| error!("tcp duplex error: {}", e))
|
||||
|
|
Loading…
Reference in New Issue