diff --git a/proxy/src/connection.rs b/proxy/src/connection.rs index 2f38c5a90..10ccbbed7 100644 --- a/proxy/src/connection.rs +++ b/proxy/src/connection.rs @@ -252,6 +252,9 @@ impl Connection { impl io::Read for Connection { fn read(&mut self, buf: &mut [u8]) -> io::Result { + // TODO: Eliminate the duplication between this and + // `transport::prefixed::Prefixed`. + // Check the length only once, since looking as the length // of a BytesMut isn't as cheap as the length of a &[u8]. let peeked_len = self.peek_buf.len(); @@ -266,7 +269,7 @@ impl io::Read for Connection { // hold onto the allocated memory any longer. We won't peek // again. if peeked_len == len { - self.peek_buf = BytesMut::new(); + self.peek_buf = Default::default(); } Ok(len) } diff --git a/proxy/src/transport/mod.rs b/proxy/src/transport/mod.rs index abb9243c7..21283a8d0 100644 --- a/proxy/src/transport/mod.rs +++ b/proxy/src/transport/mod.rs @@ -1,6 +1,7 @@ mod connect; mod addr_info; mod io; +mod prefixed; pub mod tls; pub use self::connect::{ diff --git a/proxy/src/transport/prefixed.rs b/proxy/src/transport/prefixed.rs new file mode 100644 index 000000000..1e7cb52ac --- /dev/null +++ b/proxy/src/transport/prefixed.rs @@ -0,0 +1,92 @@ +#![allow(dead_code)] // TODO: Actually use this. + +use std::{cmp, fmt::Debug, io, net::SocketAddr}; + +use super::io::internal::Io; +use bytes::{Buf, Bytes}; +use tokio::prelude::*; +use AddrInfo; + +/// A TcpStream where the initial reads will be served from `prefix`. +#[derive(Debug)] +pub struct Prefixed { + prefix: Bytes, + io: S, +} + +impl Prefixed { + pub fn new(prefix: Bytes, io: S) -> Self { + Self { prefix, io } + } +} + +impl io::Read for Prefixed where S: Debug + io::Read { + fn read(&mut self, buf: &mut [u8]) -> Result { + // Check the length only once, since looking as the length + // of a Bytes isn't as cheap as the length of a &[u8]. + let peeked_len = self.prefix.len(); + + if peeked_len == 0 { + self.io.read(buf) + } else { + let len = cmp::min(buf.len(), peeked_len); + buf[..len].copy_from_slice(&self.prefix.as_ref()[..len]); + self.prefix.advance(len); + // If we've finally emptied the peek_buf, drop it so we don't + // hold onto the allocated memory any longer. We won't peek + // again. + if peeked_len == len { + self.prefix = Default::default(); + } + Ok(len) + } + } +} + +impl AsyncRead for Prefixed where S: AsyncRead + Debug { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.io.prepare_uninitialized_buffer(buf) + } +} + +impl io::Write for Prefixed where S: Debug + io::Write { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.io.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.io.flush() + } +} + +impl AsyncWrite for Prefixed where S: AsyncWrite + Debug { + fn shutdown(&mut self) -> Result, io::Error> { + self.io.shutdown() + } + + fn write_buf(&mut self, buf: &mut B) -> Poll + where Self: Sized + { + self.io.write_buf(buf) + } +} + +impl AddrInfo for Prefixed where S: AddrInfo { + fn local_addr(&self) -> Result { + self.io.local_addr() + } + + fn get_original_dst(&self) -> Option { + self.io.get_original_dst() + } +} + +impl Io for Prefixed where S: Io { + fn shutdown_write(&mut self) -> Result<(), io::Error> { + self.io.shutdown_write() + } + + fn write_buf_erased(&mut self, buf: &mut Buf) -> Result, io::Error> { + self.io.write_buf_erased(buf) + } +}