From 75034ef09d444dcc276cc470393c14103276c73b Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Mon, 25 Jun 2018 12:12:53 -1000 Subject: [PATCH] Proxy: Add `transport::prefixed::Prefixed`. (#1196) Copy most of the implementation of `connection::Connection` to create a way to prefix a `TcpStream` with some previously-read bytes. This will allow us to read and parse a TLS ClientHello message to see if it is intended for the proxy to process, and then "rewind" and feed it back into the TLS implementation if so. This must be in the `transport` submodule in order for it to implement the private `Io` trait. Signed-off-by: Brian Smith --- proxy/src/connection.rs | 5 +- proxy/src/transport/mod.rs | 1 + proxy/src/transport/prefixed.rs | 92 +++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 proxy/src/transport/prefixed.rs diff --git a/proxy/src/connection.rs b/proxy/src/connection.rs index 2f38c5a90..10ccbbed7 100644 --- a/proxy/src/connection.rs +++ b/proxy/src/connection.rs @@ -252,6 +252,9 @@ impl Connection { impl io::Read for Connection { fn read(&mut self, buf: &mut [u8]) -> io::Result { + // TODO: Eliminate the duplication between this and + // `transport::prefixed::Prefixed`. + // Check the length only once, since looking as the length // of a BytesMut isn't as cheap as the length of a &[u8]. let peeked_len = self.peek_buf.len(); @@ -266,7 +269,7 @@ impl io::Read for Connection { // hold onto the allocated memory any longer. We won't peek // again. if peeked_len == len { - self.peek_buf = BytesMut::new(); + self.peek_buf = Default::default(); } Ok(len) } diff --git a/proxy/src/transport/mod.rs b/proxy/src/transport/mod.rs index abb9243c7..21283a8d0 100644 --- a/proxy/src/transport/mod.rs +++ b/proxy/src/transport/mod.rs @@ -1,6 +1,7 @@ mod connect; mod addr_info; mod io; +mod prefixed; pub mod tls; pub use self::connect::{ diff --git a/proxy/src/transport/prefixed.rs b/proxy/src/transport/prefixed.rs new file mode 100644 index 000000000..1e7cb52ac --- /dev/null +++ b/proxy/src/transport/prefixed.rs @@ -0,0 +1,92 @@ +#![allow(dead_code)] // TODO: Actually use this. + +use std::{cmp, fmt::Debug, io, net::SocketAddr}; + +use super::io::internal::Io; +use bytes::{Buf, Bytes}; +use tokio::prelude::*; +use AddrInfo; + +/// A TcpStream where the initial reads will be served from `prefix`. +#[derive(Debug)] +pub struct Prefixed { + prefix: Bytes, + io: S, +} + +impl Prefixed { + pub fn new(prefix: Bytes, io: S) -> Self { + Self { prefix, io } + } +} + +impl io::Read for Prefixed where S: Debug + io::Read { + fn read(&mut self, buf: &mut [u8]) -> Result { + // Check the length only once, since looking as the length + // of a Bytes isn't as cheap as the length of a &[u8]. + let peeked_len = self.prefix.len(); + + if peeked_len == 0 { + self.io.read(buf) + } else { + let len = cmp::min(buf.len(), peeked_len); + buf[..len].copy_from_slice(&self.prefix.as_ref()[..len]); + self.prefix.advance(len); + // If we've finally emptied the peek_buf, drop it so we don't + // hold onto the allocated memory any longer. We won't peek + // again. + if peeked_len == len { + self.prefix = Default::default(); + } + Ok(len) + } + } +} + +impl AsyncRead for Prefixed where S: AsyncRead + Debug { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.io.prepare_uninitialized_buffer(buf) + } +} + +impl io::Write for Prefixed where S: Debug + io::Write { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.io.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.io.flush() + } +} + +impl AsyncWrite for Prefixed where S: AsyncWrite + Debug { + fn shutdown(&mut self) -> Result, io::Error> { + self.io.shutdown() + } + + fn write_buf(&mut self, buf: &mut B) -> Poll + where Self: Sized + { + self.io.write_buf(buf) + } +} + +impl AddrInfo for Prefixed where S: AddrInfo { + fn local_addr(&self) -> Result { + self.io.local_addr() + } + + fn get_original_dst(&self) -> Option { + self.io.get_original_dst() + } +} + +impl Io for Prefixed where S: Io { + fn shutdown_write(&mut self) -> Result<(), io::Error> { + self.io.shutdown_write() + } + + fn write_buf_erased(&mut self, buf: &mut Buf) -> Result, io::Error> { + self.io.write_buf_erased(buf) + } +}